# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Main training loop."""

import logging

from dora import get_xp
from dora.utils import write_and_rename
from dora.log import LogProgress, bold
import torch
import torch.nn.functional as F

from . import augment, distrib, states, pretrained
from .apply import apply_model
from .ema import ModelEMA
from .evaluate import evaluate, new_sdr
from .svd import svd_penalty
from .utils import pull_metric, EMA

logger = logging.getLogger(__name__)


def _summary(metrics):
    return " | ".join(f"{key.capitalize()}={val}" for key, val in metrics.items())


class Solver(object):
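    """Training and evaluation loop for a source separation model.

    Keeps track of checkpoints, EMA copies of the weights, data augmentation,
    train/valid/test metrics and the best model selected on the validation metric.
    """
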
    def __init__(self, loaders, model, optimizer, args):
        self.args = args
        self.loaders = loaders

        self.model = model
        self.optimizer = optimizer
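        # Optional quantizer; when enabled, its estimated model size is penalized
        # in the loss (see `args.quant.diffq` in `_run_one_epoch`).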
        self.quantizer = states.get_quantizer(self.model, args.quant, self.optimizer)
        self.dmodel = distrib.wrap(model)
        self.device = next(iter(self.model.parameters())).device

        # Exponential moving average of the model, either updated every batch or epoch.
        # The best model from all the EMAs and the original one is kept based on the valid
        # loss for the final best model.
        self.emas = {'batch': [], 'epoch': []}
        for kind in self.emas.keys():
            decays = getattr(args.ema, kind)
            device = self.device if kind == 'batch' else 'cpu'
            if decays:
                for decay in decays:
                    self.emas[kind].append(ModelEMA(self.model, decay, device=device))

        # data augment
        augments = [augment.Shift(shift=int(args.dset.samplerate * args.dset.shift),
                                  same=args.augment.shift_same)]
        if args.augment.flip:
            augments += [augment.FlipChannels(), augment.FlipSign()]
        for aug in ['scale', 'remix']:
            kw = getattr(args.augment, aug)
            if kw.proba:
                augments.append(getattr(augment, aug.capitalize())(**kw))
        self.augment = torch.nn.Sequential(*augments)

        xp = get_xp()
        self.folder = xp.folder
        # Checkpoints
        self.checkpoint_file = xp.folder / 'checkpoint.th'
        self.best_file = xp.folder / 'best.th'
        logger.debug("Checkpoint will be saved to %s", self.checkpoint_file.resolve())
        self.best_state = None
        self.best_changed = False

        self.link = xp.link
        self.history = self.link.history
        self._reset()

    def _serialize(self, epoch):
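        """Save the training state to `checkpoint.th`, plus periodic and best-model copies when needed."""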
        package = {}
        package['state'] = self.model.state_dict()
        package['optimizer'] = self.optimizer.state_dict()
        package['history'] = self.history
        package['best_state'] = self.best_state
        package['args'] = self.args
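        # Also persist every EMA copy of the weights so training can resume with the same averages.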
        for kind, emas in self.emas.items():
            for k, ema in enumerate(emas):
                package[f'ema_{kind}_{k}'] = ema.state_dict()
        with write_and_rename(self.checkpoint_file) as tmp:
            torch.save(package, tmp)

        save_every = self.args.save_every
        if save_every and (epoch + 1) % save_every == 0 and epoch + 1 != self.args.epochs:
            with write_and_rename(self.folder / f'checkpoint_{epoch + 1}.th') as tmp:
                torch.save(package, tmp)

        if self.best_changed:
            # Saving only the latest best model.
            with write_and_rename(self.best_file) as tmp:
                package = states.serialize_model(self.model, self.args)
                package['state'] = self.best_state
                torch.save(package, tmp)
            self.best_changed = False

    def _reset(self):
        """Reset state of the solver, potentially using checkpoint."""
        if self.checkpoint_file.exists():
            logger.info(f'Loading checkpoint model: {self.checkpoint_file}')
            package = torch.load(self.checkpoint_file, 'cpu')
            self.model.load_state_dict(package['state'])
            self.optimizer.load_state_dict(package['optimizer'])
            self.history[:] = package['history']
            self.best_state = package['best_state']
            for kind, emas in self.emas.items():
                for k, ema in enumerate(emas):
                    ema.load_state_dict(package[f'ema_{kind}_{k}'])
        elif self.args.continue_pretrained:
            model = pretrained.get_model(
                name=self.args.continue_pretrained,
                repo=self.args.pretrained_repo)
            self.model.load_state_dict(model.state_dict())
        elif self.args.continue_from:
            name = 'checkpoint.th'
            root = self.folder.parent
            cf = root / str(self.args.continue_from) / name
            logger.info("Loading from %s", cf)
            package = torch.load(cf, 'cpu')
            self.best_state = package['best_state']
            if self.args.continue_best:
                self.model.load_state_dict(package['best_state'], strict=False)
            else:
                self.model.load_state_dict(package['state'], strict=False)
            if self.args.continue_opt:
                self.optimizer.load_state_dict(package['optimizer'])

    def _format_train(self, metrics: dict) -> dict:
        """Formatting for train/valid metrics."""
        losses = {
            'loss': format(metrics['loss'], ".4f"),
            'reco': format(metrics['reco'], ".4f"),
        }
        if 'nsdr' in metrics:
            losses['nsdr'] = format(metrics['nsdr'], ".3f")
        if self.quantizer is not None:
            losses['ms'] = format(metrics['ms'], ".2f")
        if 'grad' in metrics:
            losses['grad'] = format(metrics['grad'], ".4f")
        if 'best' in metrics:
            losses['best'] = format(metrics['best'], '.4f')
        if 'bname' in metrics:
            losses['bname'] = metrics['bname']
        if 'penalty' in metrics:
            losses['penalty'] = format(metrics['penalty'], ".4f")
        if 'hloss' in metrics:
            losses['hloss'] = format(metrics['hloss'], ".4f")
        return losses

    def _format_test(self, metrics: dict) -> dict:
        """Formatting for test metrics."""
        losses = {}
        if 'sdr' in metrics:
            losses['sdr'] = format(metrics['sdr'], '.3f')
        if 'nsdr' in metrics:
            losses['nsdr'] = format(metrics['nsdr'], '.3f')
        for source in self.model.sources:
            key = f'sdr_{source}'
            if key in metrics:
                losses[key] = format(metrics[key], '.3f')
            key = f'nsdr_{source}'
            if key in metrics:
                losses[key] = format(metrics[key], '.3f')
        return losses

    def train(self):
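        """Run the full training loop: train, validate, periodically test and checkpoint."""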
        # Optimizing the model
        if self.history:
            logger.info("Replaying metrics from previous run")
        for epoch, metrics in enumerate(self.history):
            formatted = self._format_train(metrics['train'])
            logger.info(
                bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}'))
            formatted = self._format_train(metrics['valid'])
            logger.info(
                bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}'))
            if 'test' in metrics:
                formatted = self._format_test(metrics['test'])
                if formatted:
                    logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}"))

        epoch = 0
        for epoch in range(len(self.history), self.args.epochs):
            # Train one epoch
            self.model.train()  # Turn on BatchNorm & Dropout
            metrics = {}
            logger.info('-' * 70)
            logger.info("Training...")
            metrics['train'] = self._run_one_epoch(epoch)
            formatted = self._format_train(metrics['train'])
            logger.info(
                bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}'))

            # Cross validation
            logger.info('-' * 70)
            logger.info('Cross validation...')
            self.model.eval()  # Turn off Batchnorm & Dropout
            with torch.no_grad():
                valid = self._run_one_epoch(epoch, train=False)
                bvalid = valid
                bname = 'main'
                state = states.copy_state(self.model.state_dict())
                metrics['valid'] = {}
                metrics['valid']['main'] = valid
                key = self.args.test.metric
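                # Re-run validation with each EMA copy of the weights swapped in, keeping the
                # best candidate according to `key` (nSDR is negated so that lower is better).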
                for kind, emas in self.emas.items():
                    for k, ema in enumerate(emas):
                        with ema.swap():
                            valid = self._run_one_epoch(epoch, train=False)
                        name = f'ema_{kind}_{k}'
                        metrics['valid'][name] = valid
                        a = valid[key]
                        b = bvalid[key]
                        if key.startswith('nsdr'):
                            a = -a
                            b = -b
                        if a < b:
                            bvalid = valid
                            state = ema.state
                            bname = name
                metrics['valid'].update(bvalid)
                metrics['valid']['bname'] = bname

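            # Compare with the history of this validation metric to decide whether this epoch is a new best.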
            valid_loss = metrics['valid'][key]
            mets = pull_metric(self.link.history, f'valid.{key}') + [valid_loss]
            if key.startswith('nsdr'):
                best_loss = max(mets)
            else:
                best_loss = min(mets)
            metrics['valid']['best'] = best_loss
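            # When the SVD penalty is enabled, also report its exact value with the validation metrics.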
            if self.args.svd.penalty > 0:
                kw = dict(self.args.svd)
                kw.pop('penalty')
                with torch.no_grad():
                    penalty = svd_penalty(self.model, exact=True, **kw)
                metrics['valid']['penalty'] = penalty

            formatted = self._format_train(metrics['valid'])
            logger.info(
                bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}'))

            # Save the best model
            if valid_loss == best_loss or self.args.dset.train_valid:
                logger.info(bold('New best valid loss %.4f'), valid_loss)
                self.best_state = states.copy_state(state)
                self.best_changed = True

            # Eval model every `test.every` epoch or on last epoch
            should_eval = (epoch + 1) % self.args.test.every == 0
            is_last = epoch == self.args.epochs - 1
            # # Tries to detect divergence in a reliable way and finish job
            # # not to waste compute.
            # # Commented out as this was super specific to the MDX competition.
            # reco = metrics['valid']['main']['reco']
            # div = epoch >= 180 and reco > 0.18
            # div = div or epoch >= 100 and reco > 0.25
            # div = div and self.args.optim.loss == 'l1'
            # if div:
            #     logger.warning("Finishing training early because valid loss is too high.")
            #     is_last = True
            if should_eval or is_last:
                # Evaluate on the testset
                logger.info('-' * 70)
                logger.info('Evaluating on the test set...')
                # We switch to the best known model for testing
                if self.args.test.best:
                    state = self.best_state
                else:
                    state = states.copy_state(self.model.state_dict())
                compute_sdr = self.args.test.sdr and is_last
                with states.swap_state(self.model, state):
                    with torch.no_grad():
                        metrics['test'] = evaluate(self, compute_sdr=compute_sdr)
                formatted = self._format_test(metrics['test'])
                logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}"))
            self.link.push_metrics(metrics)

            if distrib.rank == 0:
                # Save model each epoch
                self._serialize(epoch)
                logger.debug("Checkpoint saved to %s", self.checkpoint_file.resolve())
            if is_last:
                break

    def _run_one_epoch(self, epoch, train=True):
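        """Run one epoch over the train or valid loader and return the averaged metrics."""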
        args = self.args
        data_loader = self.loaders['train'] if train else self.loaders['valid']
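        # Reshuffle the distributed sampler so that each epoch sees a different ordering.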
        if distrib.world_size > 1 and train:
            data_loader.sampler.set_epoch(epoch)

        label = ["Valid", "Train"][train]
        name = label + f" | Epoch {epoch + 1}"
        total = len(data_loader)
        if args.max_batches:
            total = min(total, args.max_batches)
        logprog = LogProgress(logger, data_loader, total=total,
                              updates=self.args.misc.num_prints, name=name)
        averager = EMA()

        for idx, sources in enumerate(logprog):
            sources = sources.to(self.device)
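            # During training the mix is built from the augmented stems; during validation
            # the first entry of the batch is the precomputed mix and the rest are the targets.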
            if train:
                sources = self.augment(sources)
                mix = sources.sum(dim=1)
            else:
                mix = sources[:, 0]
                sources = sources[:, 1:]

            if not train and self.args.valid_apply:
                estimate = apply_model(self.model, mix, split=self.args.test.split, overlap=0)
            else:
                estimate = self.dmodel(mix)
            if train and hasattr(self.model, 'transform_target'):
                sources = self.model.transform_target(mix, sources)
            assert estimate.shape == sources.shape, (estimate.shape, sources.shape)
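            # Per-source reconstruction loss: average over all dims except batch and source,
            # then over the batch.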
            dims = tuple(range(2, sources.dim()))

            if args.optim.loss == 'l1':
                loss = F.l1_loss(estimate, sources, reduction='none')
                loss = loss.mean(dims).mean(0)
                reco = loss
            elif args.optim.loss == 'mse':
                loss = F.mse_loss(estimate, sources, reduction='none')
                loss = loss.mean(dims)
                reco = loss**0.5
                reco = reco.mean(0)
            else:
                raise ValueError(f"Invalid loss {args.optim.loss}")
            weights = torch.tensor(args.weights).to(sources)
            loss = (loss * weights).sum() / weights.sum()

            ms = 0
            if self.quantizer is not None:
                ms = self.quantizer.model_size()
            if args.quant.diffq:
                loss += args.quant.diffq * ms

            losses = {}
            losses['reco'] = (reco * weights).sum() / weights.sum()
            losses['ms'] = ms

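            # On validation batches, also report the nSDR metric per source and its weighted mean.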
            if not train:
                nsdrs = new_sdr(sources, estimate.detach()).mean(0)
                total = 0
                for source, nsdr, w in zip(self.model.sources, nsdrs, weights):
                    losses[f'nsdr_{source}'] = nsdr
                    total += w * nsdr
                losses['nsdr'] = total / weights.sum()

            if train and args.svd.penalty > 0:
                kw = dict(args.svd)
                kw.pop('penalty')
                penalty = svd_penalty(self.model, **kw)
                losses['penalty'] = penalty
                loss += args.svd.penalty * penalty

            losses['loss'] = loss

            for k, source in enumerate(self.model.sources):
                losses[f'reco_{source}'] = reco[k]

            # optimize model in training mode
            if train:
                loss.backward()
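                # Accumulate the total gradient norm for logging, before optional clipping.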
                grad_norm = 0
                grads = []
                for p in self.model.parameters():
                    if p.grad is not None:
                        grad_norm += p.grad.data.norm()**2
                        grads.append(p.grad.data)
                losses['grad'] = grad_norm ** 0.5
                if args.optim.clip_grad:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        args.optim.clip_grad)

                if self.args.flag == 'uns':
                    for n, p in self.model.named_parameters():
                        if p.grad is None:
                            print('no grad', n)
                self.optimizer.step()
                self.optimizer.zero_grad()
                for ema in self.emas['batch']:
                    ema.update()

            losses = averager(losses)
            logs = self._format_train(losses)
            logprog.update(**logs)

            # Just in case, clear some memory
            del loss, estimate, reco, ms
            if args.max_batches == idx:
                break
            if self.args.debug and train:
                break
            if self.args.flag == 'debug':
                break
        if train:
            for ema in self.emas['epoch']:
                ema.update()
        return distrib.average(losses, idx + 1)