File size: 8,384 Bytes
6da2a44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
import datetime
import math
import sys
import time
import logging
import os
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
import dist
from models.encoder import SparseEncoder
from models.decoder import LightDecoder
from models.MambaMIM import MambaMIM
from models import build_sparse_encoder
from utils.sampler import DistInfiniteBatchSampler, worker_init_fn
from utils import arg_util, misc
from utils.med_dataset import get_loader
from utils.lr_control import lr_wd_annealing
# Limit native math libraries and torch's intra-op pool to `cpu_num` threads.
# NOTE(review): torch is already imported above, so runtimes it initialized at
# import time may not re-read these variables in this process; they are still
# inherited by child processes started later (e.g. DataLoader workers) — confirm
# this is the intended effect.
cpu_num = 1
os.environ.update({
    env_name: str(cpu_num)
    for env_name in (
        'OMP_NUM_THREADS',
        'OPENBLAS_NUM_THREADS',
        'MKL_NUM_THREADS',
        'VECLIB_MAXIMUM_THREADS',
        'NUMEXPR_NUM_THREADS',
    )
})
torch.set_num_threads(cpu_num)
# Share tensors between processes through the file system (presumably to avoid
# running out of file descriptors with many DataLoader workers).
torch.multiprocessing.set_sharing_strategy('file_system')
class LocalDDP(torch.nn.Module):
    """Single-process stand-in for ``DistributedDataParallel``.

    The wrapped network is stored under the ``module`` attribute, so parameter
    and state-dict names carry a ``module.`` prefix, and every call is forwarded
    to it unchanged. No gradient synchronization is performed.
    """

    def __init__(self, module):
        super().__init__()
        # Registered as a child module; its parameters become "module.<name>".
        self.module = module

    def forward(self, *args, **kwargs):
        inner = self.module
        return inner(*args, **kwargs)
def main_pt():
    """Run the MambaMIM pre-training job end to end.

    Steps:
      1. Build ``args`` via ``arg_util.init_dist_and_get_args()`` and log them.
      2. Build the training DataLoader over ``get_loader(args.data_path,
         args.input_size)`` with a ``DistInfiniteBatchSampler`` parameterized by
         the current ``dist`` rank and world size.
      3. Build the sparse encoder, the light decoder and the ``MambaMIM`` model,
         wrap it in ``LocalDDP`` and create an AdamW optimizer.
      4. Resume via ``misc.load_checkpoint(args.resume_from, ...)``; if the
         returned start epoch is already ``>= args.ep`` nothing is trained.
      5. Otherwise run ``pre_train_one_ep`` for epochs ``ep_start .. args.ep - 1``,
         saving two checkpoints after every epoch (full model + optimizer state,
         and encoder-only weights) and logging progress to stdout,
         ``args.log_epoch()`` and TensorBoard.
    """
    args: arg_util.Args = arg_util.init_dist_and_get_args()
    print(f'initial args:\n{str(args)}')
    args.log_epoch()

    # build data
    print(f'[build data for pre-training] ...\n')
    dataset_train = get_loader(args.data_path, args.input_size)
    data_loader_train = DataLoader(
        dataset=dataset_train, num_workers=args.dataloader_workers, pin_memory=True,
        batch_sampler=DistInfiniteBatchSampler(
            dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size,
            shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(),
        ), worker_init_fn=worker_init_fn
    )
    # A single iterator is reused across epochs; len() gives the iterations per epoch.
    itrt_train, iters_train = iter(data_loader_train), len(data_loader_train)
    print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size_per_gpu}, iters_train={iters_train}')

    # build encoder and decoder
    enc: SparseEncoder = build_sparse_encoder(args.model, input_size=args.input_size, sbn=args.sbn, drop_path_rate=args.dp, verbose=False)
    dec = LightDecoder(enc.downsample_raito, sbn=args.sbn)
    model_without_ddp = MambaMIM(
        sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask,
        densify_norm=args.densify_norm, sbn=args.sbn,
    ).to(args.device)
    print(f'[PT model] model = {model_without_ddp}\n')

    # The model keeps the random initialization from its constructors; no
    # separate weight-initialization checkpoint is loaded at this point.
    # LocalDDP only forwards calls to the wrapped module (no cross-process
    # gradient synchronization).
    model = LocalDDP(model_without_ddp)

    # build optimizer (no separate lr scheduler object is created).
    # NOTE(review): lr / weight_decay here look like initial values only; the
    # per-step values are presumably set by lr_wd_annealing in pre_train_one_ep.
    optimizer = torch.optim.AdamW(params=model_without_ddp.parameters(), lr=args.lr, weight_decay=1e-5)

    # Try to resume the experiment from some checkpoint.pth; this loads model
    # weights, optimizer states and the last epoch (ep_start).
    # If loaded, ep_start will be greater than 0.
    ep_start, performance_desc = misc.load_checkpoint(args.resume_from, model_without_ddp, optimizer)
    if ep_start >= args.ep:  # loaded from a complete checkpoint file
        print(f' [*] [PT already done] Min/Last Recon Loss: {performance_desc}')
    else:  # perform pre-training
        tb_lg = misc.TensorboardLogger(args.tb_lg_dir, is_master=dist.is_master(), prefix='pt')
        min_loss = 1e9
        print(f'[PT start] from ep{ep_start}')
        pt_start_time = time.time()
        for ep in range(ep_start, args.ep):
            ep_start_time = time.time()
            tb_lg.set_step(ep * iters_train)
            if hasattr(itrt_train, 'set_epoch'):
                itrt_train.set_epoch(ep)
            stats = pre_train_one_ep(ep, args, tb_lg, itrt_train, iters_train, model, optimizer)
            last_loss = stats['last_loss']
            min_loss = min(min_loss, last_loss)
            performance_desc = f'{min_loss:.4f} {last_loss:.4f}'
            # Full checkpoint (encoder + decoder + optimizer) and encoder-only weights.
            misc.save_checkpoint_with_meta_info_and_opt_state(f'{args.model}_withdecoder_ct_pretrained.pth', args, ep, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())
            misc.save_checkpoint_model_weights_only(f'{args.model}_ct_pretrained_mambamim_timm_style.pth', args, model_without_ddp.sparse_encoder.state_dict())

            # ETA: extrapolate this epoch's wall time over the remaining epochs.
            ep_cost = round(time.time() - ep_start_time, 2) + 1  # +1s: approximate the following logging cost
            remain_secs = (args.ep - 1 - ep) * ep_cost
            remain_time = datetime.timedelta(seconds=round(remain_secs))
            finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs))
            print(f' [*] [ep{ep}/{args.ep}] Min/Last Recon Loss: {performance_desc}, Cost: {ep_cost}s, Remain: {remain_time}, Finish @ {finish_time}')

            args.cur_ep = f'{ep + 1}/{args.ep}'
            args.remain_time, args.finish_time = str(remain_time), str(finish_time)
            args.last_loss = last_loss
            args.log_epoch()

            tb_lg.update(min_loss=min_loss, head='train', step=ep)
            tb_lg.update(rest_hours=round(remain_secs/60/60, 2), head='z_burnout', step=ep)
            tb_lg.flush()

        # finish pre-training
        tb_lg.update(min_loss=min_loss, head='result', step=ep_start)
        tb_lg.update(min_loss=min_loss, head='result', step=args.ep)
        tb_lg.flush()
        print(f'final args:\n{str(args)}')
        print('\n\n')
        print(f' [*] [PT finished] Min/Last Recon Loss: {performance_desc}, Total Cost: {(time.time() - pt_start_time) / 60 / 60:.1f}h\n')
        print('\n\n')
        tb_lg.close()
        time.sleep(10)

    # Runs whether or not training happened: mark the run as finished.
    args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time()))
    args.log_epoch()
def pre_train_one_ep(ep, args: arg_util.Args, tb_lg: misc.TensorboardLogger, itrt_train, iters_train, model: DistributedDataParallel, optimizer):
    """Run one pre-training epoch consisting of ``iters_train`` optimizer steps.

    Args:
        ep: zero-based epoch index; combined with ``iters_train`` to form the
            global step passed to ``lr_wd_annealing``.
        args: run configuration; reads ``lr``, ``wd``, ``wde``, ``wp_ep``, ``ep``,
            ``clip`` and ``device``.
        tb_lg: TensorBoard logger; its step is advanced once per iteration.
        itrt_train: iterator over training batches. Each batch is iterated as a
            sequence of dicts whose ``"image"`` tensors are concatenated along
            dim 0 (presumably several random crops per sample produced by the
            dataset — confirm against ``get_loader``).
        iters_train: number of iterations that make up one epoch.
        model: module whose forward returns a single-element loss tensor.
        optimizer: stepped once per iteration. If it exposes ``global_grad_norm``
            that value is logged; otherwise, when ``args.clip > 0``, gradients
            are clipped to ``args.clip`` before the step.

    Returns:
        dict mapping each MetricLogger meter name (``'last_loss'``, ``'max_lr'``
        and, when a gradient norm is available, ``'orig_norm'``) to its global
        average over the epoch.
    """
    model.train()
    me = misc.MetricLogger(delimiter=' ')
    me.add_meter('max_lr', misc.SmoothedValue(window_size=1, fmt='{value:.5f}'))
    header = f'[PT] Epoch {ep}:'

    optimizer.zero_grad()
    # If the optimizer exposes `global_grad_norm`, that norm is logged after the
    # step; otherwise gradients are clipped here when `args.clip > 0`.
    early_clipping = args.clip > 0 and not hasattr(optimizer, 'global_grad_norm')
    late_clipping = hasattr(optimizer, 'global_grad_norm')
    if early_clipping:
        params_req_grad = [p for p in model.parameters() if p.requires_grad]

    for it, inp in enumerate(me.log_every(iters_train, itrt_train, 3, header)):
        # adjust lr and wd for the current global step
        min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(optimizer, args.lr, args.wd, args.wde, it + ep * iters_train, args.wp_ep * iters_train, args.ep * iters_train)

        # merge the per-crop sub-batches into one batch along dim 0
        inp = torch.cat([crop_per_batch["image"] for crop_per_batch in inp], dim=0)
        inp = inp.to(args.device, non_blocking=True)

        # forward and backward
        loss = model(inp, active_b1fff=None, vis=False)
        optimizer.zero_grad()
        loss.backward()
        loss = loss.item()
        if not math.isfinite(loss):
            # NOTE(review): builtin print() rejects `force=`; this relies on print
            # having been replaced by the distributed setup — confirm.
            print(f'[rk{dist.get_rank():02d}] Loss is {loss}, stopping training!', force=True, flush=True)
            sys.exit(-1)

        # optimize
        grad_norm = None
        if early_clipping:
            grad_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip).item()
        optimizer.step()
        if late_clipping:
            grad_norm = optimizer.global_grad_norm
        # Only meaningful (and only legal) when a CUDA runtime is present.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # log
        me.update(last_loss=loss)
        me.update(max_lr=max_lr)
        tb_lg.update(loss=me.meters['last_loss'].global_avg, head='train_loss')
        tb_lg.update(sche_lr=max_lr, head='train_hp/lr_max')
        tb_lg.update(sche_lr=min_lr, head='train_hp/lr_min')
        tb_lg.update(sche_wd=max_wd, head='train_hp/wd_max')
        tb_lg.update(sche_wd=min_wd, head='train_hp/wd_min')
        if grad_norm is not None:
            me.update(orig_norm=grad_norm)
            tb_lg.update(orig_norm=grad_norm, head='train_hp')
        tb_lg.set_step()

    me.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in me.meters.items()}
# Script entry point: start pre-training only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main_pt()
|