|
|
|
|
|
|
|
# This module is executed as a training script and exports no public names.
__all__ = list()
|
|
|
|
|
import io |
|
import time |
|
import random |
|
from pathlib import Path |
|
|
|
from fastprogress import progress_bar, master_bar |
|
import fastprogress |
|
import wandb |
|
|
|
import numpy as np |
|
import pylab as plt |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data.dataloader import DataLoader |
|
from torch.profiler import record_function |
|
|
|
|
|
import lightning.pytorch as pl |
|
import math |
|
|
|
class TrainingTask(pl.LightningModule):
    """LightningModule that trains a wrapped ``model``.

    ``model.forward(*batch)`` must return a ``(logits, loss)`` pair; only the
    loss is used and logged. ``model_hparams`` is a dict that
    :meth:`configure_optimizers` reads (``lr0``, ``weight_decay``, ``epochs``,
    ``warmup_steps``) and writes (``pct_start``).
    """

    def __init__(self, model, model_hparams=None):
        super().__init__()
        self.model = model
        self.model_hparams = model_hparams

    def on_fit_start(self):
        """Call ``self.model.setup(self.device)`` if the model defines ``setup``."""
        # Default of None: models without a ``setup`` attribute are skipped
        # instead of raising AttributeError.
        if getattr(self.model, 'setup', None):
            self.model.setup(self.device)

    def configure_optimizers(self):
        """Build the AdamW optimizer and a per-step OneCycleLR schedule.

        Parameter groups:
          * every module with a ``no_weight_decay`` and/or ``lr_scale`` attribute
            puts all of its parameters into a group keyed by
            (weight decay, learning rate): weight decay is 0 when
            ``no_weight_decay`` is present, otherwise ``weight_decay``; the
            learning rate is ``lr0 * lr_scale`` (scale defaults to 1);
          * all remaining parameters go into a final ``"other"`` group with
            ``weight_decay`` and the optimizer-level learning rate ``lr0``.

        Side effect: stores the computed warm-up fraction in
        ``self.model_hparams['pct_start']``.

        Returns:
            ``([optimizer], [{'scheduler': lr_scheduler, 'interval': 'step'}])``.
        """
        lr = self.model_hparams['lr0']
        weight_decay = self.model_hparams['weight_decay']

        customized_params = set()
        groups = []
        group_map = {}
        # Iterate this task's own model, not a module-level global.
        for name, m in self.model.named_modules():
            if not (hasattr(m, 'no_weight_decay') or hasattr(m, 'lr_scale')):
                continue
            customized_params |= set(m.parameters())
            m_wd = 0 if hasattr(m, 'no_weight_decay') else weight_decay
            m_lr = lr * getattr(m, 'lr_scale', 1)
            group = group_map.get((m_wd, m_lr))
            if group is None:
                group = {"params": [], "names": [], "weight_decay": m_wd, "lr": m_lr}
                groups.append(group)
                group_map[(m_wd, m_lr)] = group
            group['params'] += m.parameters()
            group['names'].append(name)

        # Keep the model's parameter order instead of using a set difference:
        # set iteration order follows object ids, differs between runs, and
        # would misalign optimizer state restored from a checkpoint.
        other_params = [p for p in self.model.parameters() if p not in customized_params]

        param_groups = groups + [
            {"names": ["other"], "params": other_params, "weight_decay": weight_decay},
        ]

        optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), params=param_groups)

        def num_steps_per_epoch() -> int:
            """Return the number of optimizer steps in one training epoch."""
            # Relies on Lightning's private ``fit_loop._data_source``; may break on upgrades.
            dataloader = self.trainer.fit_loop._data_source.dataloader()
            # Round up: under-estimating the step count makes OneCycleLR raise
            # once it is stepped past its total.
            return math.ceil(len(dataloader) / self.trainer.accumulate_grad_batches)

        # The dataloader does not change here, so compute the step count once.
        steps_per_epoch = num_steps_per_epoch()
        total_steps = self.model_hparams['epochs'] * steps_per_epoch
        # Rise for ``warmup_steps`` steps, but never for more than 30% of the schedule.
        self.model_hparams['pct_start'] = min(0.3, self.model_hparams['warmup_steps'] / total_steps)

        print(f"{self.model_hparams['epochs']=} epochs x {steps_per_epoch=} steps")

        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            pct_start=self.model_hparams['pct_start'],
            # One peak learning rate per parameter group, in param_groups order.
            max_lr=[pg.get('lr', lr) for pg in param_groups],
            steps_per_epoch=steps_per_epoch,
            epochs=int(self.model_hparams['epochs']),
            final_div_factor=25,
        )

        return [optimizer], [{'scheduler': lr_scheduler, 'interval': 'step'}]

    def training_step(self, train_batch, batch_idx):
        """Run the model on one training batch; log and return its loss."""
        train_logits, train_loss = self.model.forward(*train_batch)
        self.log("train_loss", train_loss, sync_dist=True)
        return train_loss

    def validation_step(self, val_batch, batch_idx):
        """Run the model on one validation batch; log and return its loss."""
        val_logits, val_loss = self.model.forward(*val_batch)
        self.log("val_loss", val_loss, sync_dist=True)
        return val_loss

    def on_validation_epoch_end(self):
        """Log the model's ``get_metrics()`` results (if any) under ``metrics/<name>``."""
        if hasattr(self.model, 'get_metrics'):
            self.log_dict({'metrics/' + k: v for k, v in self.model.get_metrics().items()}, sync_dist=True)

    def test_step(self, val_batch, batch_idx):
        """Run the model on one test batch; log and return its loss."""
        test_logits, test_loss = self.model.forward(*val_batch)
        self.log("test_loss", test_loss, sync_dist=True)
        return test_loss
|
|
|
|
|
from fastcore.script import anno_parser |
|
import shlex |
|
|
|
|
|
|
|
def parse_and_call(name, fun, args, kwargs=None, log_to_wandb=True):
    """Parse CLI-style tokens against ``fun``'s signature and call ``fun``.

    Args:
        name: key under which the parsed arguments are stored in the W&B run config.
        fun: callable turned into an argument parser by ``fastcore.script.anno_parser``.
        args: list of command-line tokens to parse.
        kwargs: extra keyword arguments for ``fun``; they override parsed values.
            ``None`` (the default) means no extra arguments.
        log_to_wandb: if true, record the parsed arguments (minus ``dataset``
            and ``tunables``) in the W&B config under ``name``.

    Returns:
        Whatever ``fun(**parsed_args)`` returns.
    """
    parser = anno_parser(fun)
    call_args = vars(parser.parse_args(args))
    # Drop the 'xtra' and 'pdb' entries produced by anno_parser's parser;
    # they are not arguments of ``fun``.
    call_args.pop('xtra')
    call_args.pop('pdb')
    if kwargs:
        call_args.update(kwargs)
    # Only write to the config when it is a real wandb Config object.
    if log_to_wandb and isinstance(wandb_logger.experiment.config, wandb.sdk.wandb_config.Config):
        wandb_logger.experiment.config[name] = {
            k: v for k, v in call_args.items() if k not in ('dataset', 'tunables')
        }
    return fun(**call_args)
|
|
|
|
|
import argparse

# Command-line interface of the training script.
#   --task       "<module> [make_model args...]": the first token selects the
#                module imported as whisperspeech.<module>; the rest are parsed
#                against that module's make_model.
#   --input-dir  tokens parsed against the task module's load_datasets.
#   --tunables   tokens parsed against the task module's Tunables (if any).
parser = argparse.ArgumentParser()
# Required: the script cannot pick a task module without it.
parser.add_argument('--task', type=str, required=True, help='Task to train')
parser.add_argument('--seed', type=int, default=0, help='Global training seed')
parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
parser.add_argument('--input-dir', type=str, default='', help='input data path')
parser.add_argument("--checkpoint-dir", type=str, default="./checkpoints/", help="directory to save the checkpoints")
parser.add_argument('--epochs', type=int, default=10, help='total training epochs')
parser.add_argument('--validate-every-n-steps', type=int, default=500, help='how many training steps to run between validations')
parser.add_argument('--weight-decay', type=float, default=1e-2, help='optimizer weight decay')
parser.add_argument('--lr0', type=float, default=1e-4, help='optimizer initial learning rate')
parser.add_argument('--clip-gradient-norm', type=float, default=None, help='enable gradient norm clipping')
parser.add_argument('--accumulate-grad-batches', type=int, default=1, help='perform the optimizer step only after going through several batches of samples')
parser.add_argument('--precision', type=str, default="16-mixed", help="floating point precision")
parser.add_argument('--warmup-steps', type=int, default=10000, help='total number steps during which the learning rate rises (defaults to 10k updates)')
parser.add_argument('--tunables', type=str, default="", help='tunable hyperparameters')
parser.add_argument('--resume-from', type=Path, default=None, help='resume training from the given checkpoint')
parser.add_argument('--strategy', type=str, default='ddp', help='distributed training strategy')
parser.add_argument('--wandb-suffix', type=str, default=None, help='W&B project name suffix')
parser.add_argument('--wandb-task-name', type=str, default=None, help='Task name for the W&B project name')

# Plain dict of option values; entries are popped/read further below.
args = vars(parser.parse_args())
|
|
|
# Options consumed here are popped out of ``args``; the remaining entries are
# read further below.
_task_tokens = shlex.split(args.pop("task"))
# First token names the task module; the remaining tokens go to make_model.
task_name = _task_tokens[0]
task_args: list = _task_tokens[1:]
input_args: list = shlex.split(args.pop("input_dir"))
checkpoint_dir: str = args.pop("checkpoint_dir")
num_workers: int = args.pop("workers")
batch_size: int = args.pop("batch_size")
epochs: int = args.pop("epochs")
tunables_args: list = shlex.split(args.pop("tunables"))

# Hyperparameters used by TrainingTask, the Trainer and the W&B config.
hyp_params = dict(
    batch_size=batch_size,
    warmup_steps=args['warmup_steps'],
    weight_decay=args['weight_decay'],
    clip_gradient_norm=args['clip_gradient_norm'],
    accumulate_grad_batches=args['accumulate_grad_batches'],
    precision=args['precision'],
    lr0=args['lr0'],
    epochs=epochs,
    strategy=args['strategy'],
)
|
|
|
|
|
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor
import datetime
import webdataset as wds
import importlib

# Allow faster, lower-precision internal math for float32 matmuls
# (see torch.set_float32_matmul_precision).
torch.set_float32_matmul_precision('medium')

# W&B project: "WhisperSpeech-<wandb task name or task>" plus an optional "-<suffix>".
project = f"WhisperSpeech-{args['wandb_task_name'] or task_name}"
if args['wandb_suffix']:
    project += "-" + args['wandb_suffix']

wandb_logger = WandbLogger(project=project)

# Keep the 4 best checkpoints by val_loss, checked on a 5-minute wall-clock
# interval. Checkpoints are written to <--checkpoint-dir>/<task>-<epochs>e.
ckpt_callback = pl.callbacks.ModelCheckpoint(
    dirpath=str(Path(checkpoint_dir) / f'{task_name}-{epochs}e'),
    filename=task_name + "-{epoch}-{step}-{val_loss:.2f}",
    monitor="val_loss",
    save_top_k=4,
    train_time_interval=datetime.timedelta(minutes=5),
)

lr_monitor_callback = LearningRateMonitor(logging_interval='step')
|
|
|
from torch.utils.data import DataLoader

# The task module provides load_datasets, make_model and, optionally, Tunables.
task = importlib.import_module("whisperspeech." + task_name)

train_ds, val_ds = parse_and_call('dataset', task.load_datasets, input_args)

# Optional task-specific tunables. Any of lr0 / clip_gradient_norm /
# weight_decay / warmup_steps that they set (non-None) override the CLI values.
tunables = None
if hasattr(task, "Tunables"):
    import dataclasses
    tunables = parse_and_call('tunables', task.Tunables, tunables_args, log_to_wandb=False)
    if isinstance(wandb_logger.experiment.config, wandb.sdk.wandb_config.Config):
        wandb_logger.experiment.config['tunables'] = dataclasses.asdict(tunables)

    for name in ["lr0", "clip_gradient_norm", "weight_decay", "warmup_steps"]:
        val = getattr(tunables, name, None)
        if val is not None:
            hyp_params[name] = val

# For IterableDatasets the loader is given batch_size=None and no shuffling;
# map-style datasets are batched and shuffled by the loader itself.
if isinstance(train_ds, torch.utils.data.IterableDataset):
    dl_batch_size, dl_shuffle = None, False
    pin_memory = False
else:
    dl_batch_size, dl_shuffle = batch_size, True
    pin_memory = True

# Both loaders re-batch their output: unbatch, shuffle through a 1024-sample
# buffer, batch to ``batch_size``. The reported length is
# total_samples // batch_size.
val_loader = wds.WebLoader(val_ds,
                           batch_size=dl_batch_size,
                           num_workers=num_workers,
                           drop_last=False,
                           pin_memory=pin_memory).unbatched().shuffle(1024).batched(batch_size).with_length(val_ds.total_samples // batch_size)

train_loader = wds.WebLoader(train_ds,
                             batch_size=dl_batch_size,
                             num_workers=num_workers,
                             drop_last=False,
                             shuffle=dl_shuffle,
                             pin_memory=pin_memory).unbatched().shuffle(1024).batched(batch_size).with_length(train_ds.total_samples // batch_size)

model_kwargs = dict(dataset=train_ds)
if tunables is not None:
    model_kwargs['tunables'] = tunables
model = parse_and_call('model', task.make_model, task_args, model_kwargs)
|
|
|
# Wrap the model in the Lightning training task (this rebinds ``task``, which
# previously referred to the imported task module).
task = TrainingTask(model, model_hparams=hyp_params)

trainer = pl.Trainer(strategy=hyp_params['strategy'],
                     max_epochs=hyp_params['epochs'],
                     accelerator="gpu",
                     profiler="simple",
                     precision=hyp_params['precision'],
                     gradient_clip_val=hyp_params['clip_gradient_norm'],
                     accumulate_grad_batches=hyp_params['accumulate_grad_batches'],
                     val_check_interval=args.pop("validate_every_n_steps"),
                     enable_checkpointing=True,
                     logger=wandb_logger,
                     callbacks=[ckpt_callback, lr_monitor_callback])

if isinstance(wandb_logger.experiment.config, wandb.sdk.wandb_config.Config):
    wandb_logger.experiment.config.update(hyp_params)

# argparse always stores 'resume_from' (default None), so test the value rather
# than key membership; only pass ckpt_path when a checkpoint was given.
kwargs = {}
if args.get('resume_from') is not None:
    kwargs['ckpt_path'] = args['resume_from']
trainer.fit(model=task, train_dataloaders=train_loader, val_dataloaders=val_loader, **kwargs)
|
|