|
|
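"""Training-loop utilities: a live loss/LR dashboard (`SimpleVisual`) plus
`validate` and `train` helpers that work with both map-style datasets and
WebDataset pipelines."""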
|
|
|
|
|
__all__ = ['SimpleVisual', 'validate', 'train'] |
|
|
|
|
|
import time

from pathlib import Path
|
|
|
from fastprogress import progress_bar, master_bar |
|
import fastprogress |
|
|
|
import numpy as np |
|
import matplotlib.pyplot as plt
|
import math |
|
|
|
import IPython
from IPython.display import display
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data.dataloader import DataLoader |
|
from torch.profiler import record_function |
|
|
|
import webdataset as wds |
|
|
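# Performance flags: let cuDNN autotune convolution kernels and allow TF32
# matmuls, trading a little numerical precision for throughput on Ampere+ GPUs.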
|
torch.backends.cudnn.benchmark = True |
|
torch.backends.cudnn.enabled = True |
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
torch.set_float32_matmul_precision('medium') |
|
|
|
|
|
class SimpleVisual: |
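    """Live training dashboard: a fastprogress results table plus a matplotlib
    figure showing the train/val loss curves (log scale) and the LR schedule."""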
|
    def __init__(self, model, masterbar, total_steps):
|
self.model = model |
|
self.masterbar = masterbar |
|
self.total_steps = total_steps |
|
self.epochs = total_steps // masterbar.main_bar.total |
|
|
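        # two stacked panels sharing the x axis: losses on top, learning rate below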
|
gs = plt.GridSpec(2, 1, height_ratios=[3,1]) |
|
graph_fig = plt.figure(figsize=(10,6)) |
|
self.graph_fig = graph_fig |
|
self.loss_p = graph_fig.add_subplot(gs[0]) |
|
self.lr_p = graph_fig.add_subplot(gs[1], sharex=self.loss_p) |
|
        self.loss_p.tick_params('x', labelbottom=False)  # x labels only on the bottom (LR) panel
|
self.graph_out = None |
|
|
|
self.its = [] |
|
self.train_losses = [] |
|
self.val_losses = [] |
|
self.lr_history = [] |
|
|
|
def show(self): |
|
self.start_t = time.time() |
|
self.masterbar.write(["samples", "train", "val", "time"], table=True) |
|
self.graph_out = display(self.graph_fig, display_id=True, clear=True) |
|
|
|
def hide(self): |
|
if self.graph_out is not None: |
|
self.graph_out.update(IPython.display.HTML('')) |
|
|
|
def plot(self): |
|
loss_p, lr_p = self.loss_p, self.lr_p |
|
loss_p.clear() |
|
loss_p.plot(self.its, self.train_losses) |
|
loss_p.plot(self.its, self.val_losses) |
|
loss_p.set_xlim(0, self.total_steps) |
|
loss_p.set_yscale('log') |
|
lr_p.clear() |
|
lrs = np.array(self.lr_history) |
|
lr_p.plot(self.its, lrs) |
|
self.graph_out.update(self.graph_fig) |
|
|
|
    def add_data(self, it, lr, train_loss, val_loss):
        self.its.append(it)
        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)
|
self.lr_history.append(lr) |
|
self.plot() |
|
|
|
def add_table_row(self, it, avg_train_loss, val_loss): |
|
elapsed_t = time.time() - self.start_t |
|
self.masterbar.write([it, f"{avg_train_loss:.5f}", f"{val_loss:.5f}", fastprogress.core.format_time(elapsed_t)], table=True) |
|
|
|
def on_iter(self, bar, it, avg_train_loss, val_loss): |
|
epoch = math.ceil(it / self.total_steps * self.epochs) |
|
bar.comment = f"#{epoch}/{self.epochs} loss: {avg_train_loss:.3f} / {val_loss:.3f}" |
|
|
|
|
|
|
|
def validate(model, val, half=True, bs=16, drop_last=False, dl_workers=8, device="cuda"): |
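    """Compute the average per-sample loss of `model` over `val`.

    `model(*batch)` is expected to return a `(predictions, loss)` tuple;
    the caller is responsible for putting the model in eval mode first."""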
|
if isinstance(val, torch.utils.data.IterableDataset): |
|
        # the pipeline rebatches itself, so the loader runs with batch_size=None
        # (DataLoader forbids drop_last in that mode)
        val_loader = wds.WebLoader(val, batch_size=None, num_workers=dl_workers) \
            .unbatched().shuffle(1024).batched(bs)
|
else: |
|
val_loader = DataLoader(val, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last) |
|
|
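    # weight each batch's loss by its size so partial batches average correctly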
|
with torch.no_grad(): |
|
val_loss = 0 |
|
val_samples = 0 |
|
for args in val_loader: |
|
args = [x.to(device, non_blocking=True) for x in args] |
|
with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'): |
|
ps, loss = model(*args) |
|
N = args[0].shape[0] |
|
val_loss += loss.mean().item() * N |
|
val_samples += N |
|
val_loss = val_loss / val_samples |
|
|
|
return val_loss |
|
|
|
|
|
def train(checkpoint_path, model, train, val, half=True, bs=16, lr=1e-4, drop_last=False, |
|
weight_decay=0.1, warmup_steps=10000, epochs=10, clip_gradient_norm=None, |
|
dl_workers=8, visual_class = SimpleVisual, profiler=None, |
|
run_valid_every_iters=8000, table_row_every_iters=80000, chkpt_every_iters=None, |
|
device="cuda", trainable_params=None): |
|
if chkpt_every_iters is None: |
|
chkpt_every_iters = table_row_every_iters |
|
|
|
mb = master_bar(range(epochs)) |
|
    if isinstance(train, torch.utils.data.IterableDataset):
        # WebDataset pipelines have no len(); we rely on a total_samples attribute
        steps_per_epoch = train.total_samples // bs
        total_train_samples = epochs * train.total_samples
    else:
        steps_per_epoch = math.ceil(len(train) / bs)
        total_train_samples = epochs * len(train)
    pct_start = min(0.3, warmup_steps / (epochs * steps_per_epoch))
    visual = visual_class(model, mb, total_train_samples)
|
model.visual = visual |
|
|
|
Path(checkpoint_path).mkdir(exist_ok=True) |
|
|
|
    if isinstance(train, torch.utils.data.IterableDataset):
        # rebatching happens inside the pipeline, so both loaders use batch_size=None
        # (DataLoader forbids drop_last in that mode); partial=False already drops
        # the final short training batch
        train_loader = wds.WebLoader(train, batch_size=None, num_workers=dl_workers) \
            .unbatched().shuffle(1024).batched(bs, partial=False)
        val_loader = wds.WebLoader(val, batch_size=None, num_workers=dl_workers) \
            .unbatched().shuffle(1024).batched(bs)
    else:
        train_loader = DataLoader(train, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last, shuffle=True)
        val_loader = DataLoader(val, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last)
|
|
|
val_loss = torch.nan |
|
avg_train_loss = torch.nan |
|
|
|
if hasattr(model, 'setup'): |
|
model.setup(device) |
|
|
|
try: |
|
scheduler = None |
|
|
|
if trainable_params is None: trainable_params = model.parameters() |
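        # Modules may opt out of weight decay (a `no_weight_decay` attribute) or
        # scale their learning rate (`lr_scale`); collect such parameters into
        # dedicated optimizer groups, one per (weight_decay, lr) combination.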
|
all_params = set(trainable_params) |
|
customized_params = set() |
|
groups = [] |
|
group_map = {} |
|
for name,m in model.named_modules(): |
|
if hasattr(m, 'no_weight_decay') or hasattr(m, 'lr_scale'): |
|
m_trainable = [x for x in m.parameters() if x in all_params] |
|
if not m_trainable: continue |
|
customized_params |= set(m_trainable) |
|
m_wd = 0 if hasattr(m, 'no_weight_decay') else weight_decay |
|
m_lr = lr * getattr(m, 'lr_scale', 1) |
|
                group = group_map.get((m_wd, m_lr))
                if group is None:
|
group = {"params": [], "names": [], "weight_decay": m_wd, "lr": m_lr} |
|
groups.append(group) |
|
group_map[(m_wd, m_lr)] = group |
|
group['params'] += m_trainable |
|
group['names'].append(name) |
|
|
|
other_params = all_params - customized_params |
|
|
|
if other_params: |
|
groups = groups + [ |
|
{"names": ["other"], "params": list(other_params), "weight_decay": weight_decay }, |
|
] |
|
|
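        # AdamW with GPT-style betas; the fused kernel is only available on CUDA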
|
optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), fused=device!='cpu', params=groups) |
|
model._optimizer = optimizer |
|
        scaler = torch.cuda.amp.GradScaler(enabled=half and device!='cpu')
|
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, pct_start=pct_start, steps_per_epoch=steps_per_epoch, epochs=epochs,
            max_lr=[pg.get('lr', lr) for pg in groups],
            final_div_factor=25)
|
|
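        # `it` counts samples seen (it advances by bs every batch), so the
        # *_every_iters thresholds below are all expressed in samples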
|
it = 0 |
|
next_val_it = it + 50 |
|
next_chkpt_it = chkpt_every_iters |
|
next_table_it = table_row_every_iters |
|
|
|
visual.show() |
|
|
|
        # smooth the reported training loss over the last few batches
        running_loss = []
|
|
|
for epoch in mb: |
|
            bar = progress_bar(train_loader, total=steps_per_epoch, parent=mb)
|
for args in bar: |
|
with record_function("forward"): |
|
args = [x.to(device, non_blocking=True) for x in args] |
|
|
|
|
|
optimizer.zero_grad(set_to_none=True) |
|
|
|
with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'): |
|
ps, loss = model(*args) |
|
loss = loss.mean() |
|
|
|
with record_function("backward"): |
|
scaler.scale(loss).backward() |
|
|
|
                    if clip_gradient_norm:
                        # gradients must be unscaled in place before their norm can be clipped
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_gradient_norm)
|
|
|
scaler.step(optimizer) |
|
scaler.update() |
|
|
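                # advance the one-cycle LR schedule once per optimizer step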
|
scheduler.step() |
|
|
|
if profiler is not None: profiler.step() |
|
|
|
with record_function("running_loss"): |
|
running_loss.append(loss.item()) |
|
running_loss = running_loss[-5:] |
|
avg_train_loss = sum(running_loss)/len(running_loss) |
|
|
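                # periodic checkpointing: file names record the samples seen so far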
|
if it >= next_chkpt_it: |
|
with record_function("checkpoint"): |
|
next_chkpt_it += chkpt_every_iters |
|
torch.save(model.state_dict(), f'{checkpoint_path}/{it:08d}.pt') |
|
|
|
if it >= next_val_it: |
|
next_val_it += run_valid_every_iters |
|
with record_function("validation"): |
|
with record_function("model.eval"): |
|
model.eval() |
|
with torch.no_grad(): |
|
val_loss = 0 |
|
val_samples = 0 |
|
for args in val_loader: |
|
args = [x.to(device, non_blocking=True) for x in args] |
|
with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'): |
|
ps, loss = model(*args) |
|
N = args[0].shape[0] |
|
val_loss += loss.mean().item() * N |
|
val_samples += N |
|
val_loss = val_loss / val_samples |
|
with record_function("model.train"): |
|
model.train() |
|
with record_function("plotting"): |
|
visual.add_data(it, scheduler.get_last_lr(), avg_train_loss, val_loss) |
|
|
|
if it >= next_table_it: |
|
visual.add_table_row(it, avg_train_loss, val_loss) |
|
next_table_it += table_row_every_iters |
|
|
|
it += bs |
|
visual.on_iter(bar, it, avg_train_loss, val_loss) |
|
    except KeyboardInterrupt:
        mb.write("interrupted")
        mb.show()
|
finally: |
|
visual.add_table_row(it, avg_train_loss, val_loss) |
|
mb.show() |
|
visual.hide() |
|
|