import typing as tp

import flashy
import julius
import omegaconf
import torch
import torch.nn.functional as F

from . import builders
from . import base
from .. import models
from ..modules.diffusion_schedule import NoiseSchedule
from ..metrics import RelativeVolumeMel
from ..models.builders import get_processor
from ..utils.samples.manager import SampleManager
from ..solvers.compression import CompressionSolver

class PerStageMetrics:
    """Aggregate metrics per stage of the diffusion process.

    Reports each metric per range of diffusion steps,
    e.g. the average loss when t is in [250, 500].
    """
    def __init__(self, num_steps: int, num_stages: int = 4):
        self.num_steps = num_steps
        self.num_stages = num_stages

    def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]):
        if isinstance(step, int):
            # A single step shared by the whole batch maps to a single stage.
            stage = int((step / self.num_steps) * self.num_stages)
            return {f"{name}_{stage}": loss for name, loss in losses.items()}
        elif isinstance(step, torch.Tensor):
            # One step per batch item: average each metric over the items
            # that fall into each stage bucket.
            stage_tensor = ((step / self.num_steps) * self.num_stages).long()
            out: tp.Dict[str, float] = {}
            for stage_idx in range(self.num_stages):
                mask = (stage_tensor == stage_idx)
                N = mask.sum()
                stage_out = {}
                if N > 0:  # skip stages with no element in the batch
                    for name, loss in losses.items():
                        stage_loss = (mask * loss).sum() / N
                        stage_out[f"{name}_{stage_idx}"] = stage_loss
                out = {**out, **stage_out}
            return out
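

# Example with illustrative values: for num_steps=1000 and num_stages=4, an int
# step of 600 falls in stage int(600 / 1000 * 4) = 2, so {'loss': 0.1} is
# reported as {'loss_2': 0.1}.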


class DataProcess:
    """Apply filtering and/or resampling to the data.

    Args:
        initial_sr (int): Sample rate of the dataset.
        target_sr (int): Sample rate after resampling.
        use_resampling (bool): Whether to perform resampling.
        use_filter (bool): When True, filter the data to keep only one frequency band.
        n_bands (int): Number of bands used for the band filtering.
        idx_band (int): Index of the kept frequency band, from 0 (lows) to n_bands - 1 (highs).
        device (torch.device or str): Device on which to run the filter.
        cutoffs (list or None): Cutoff frequencies of the band filtering.
            If None, mel scale bands are used.
        boost (bool): Make the data scale match our music dataset.
    """
    def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, use_resampling: bool = False,
                 use_filter: bool = False, n_bands: int = 4,
                 idx_band: int = 0, device: torch.device = torch.device('cpu'), cutoffs=None, boost=False):
        assert idx_band < n_bands
        self.idx_band = idx_band
        if use_filter:
            if cutoffs is not None:
                self.filter = julius.SplitBands(sample_rate=initial_sr, cutoffs=cutoffs).to(device)
            else:
                self.filter = julius.SplitBands(sample_rate=initial_sr, n_bands=n_bands).to(device)
        self.use_filter = use_filter
        self.use_resampling = use_resampling
        self.target_sr = target_sr
        self.initial_sr = initial_sr
        self.boost = boost

    def process_data(self, x, metric=False):
        if x is None:
            return None
        if self.boost:
            # Normalize the scale, then rescale to match the music dataset.
            x /= torch.clamp(x.std(dim=(1, 2), keepdim=True), min=1e-4)
            x = x * 0.22
        if self.use_filter and not metric:
            x = self.filter(x)[self.idx_band]
        if self.use_resampling:
            x = julius.resample_frac(x, old_sr=self.initial_sr, new_sr=self.target_sr)
        return x

    def inverse_process(self, x):
        """Upsampling only."""
        if self.use_resampling:
            x = julius.resample_frac(x, old_sr=self.target_sr, new_sr=self.initial_sr)
        return x
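

# Minimal usage sketch (hypothetical rates and shapes): keep the low band of a
# 24 kHz signal and train at 16 kHz, then upsample predictions back.
#   dp = DataProcess(initial_sr=24000, target_sr=16000, use_resampling=True,
#                    use_filter=True, n_bands=4, idx_band=0)
#   low_band = dp.process_data(wav)       # wav is a [B, C, T] waveform tensor
#   restored = dp.inverse_process(pred)   # resampled back to 24 kHz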


class DiffusionSolver(base.StandardSolver):
    """Solver for the diffusion task.

    The diffusion task allows for MultiBand diffusion model training.

    Args:
        cfg (DictConfig): Configuration.
    """
    def __init__(self, cfg: omegaconf.DictConfig):
        super().__init__(cfg)
        self.cfg = cfg
        self.device = cfg.device
        self.sample_rate: int = self.cfg.sample_rate
        self.codec_model = CompressionSolver.model_from_checkpoint(
            cfg.compression_model_checkpoint, device=self.device)

        self.codec_model.set_num_codebooks(cfg.n_q)
        assert self.codec_model.sample_rate == self.sample_rate, (
            f"Codec model sample rate is {self.codec_model.sample_rate} but "
            f"Solver sample rate is {self.sample_rate}."
        )

        self.sample_processor = get_processor(cfg.processor, sample_rate=self.sample_rate)
        self.register_stateful('sample_processor')
        self.sample_processor.to(self.device)

        self.schedule = NoiseSchedule(
            **cfg.schedule, device=self.device, sample_processor=self.sample_processor)

        self.eval_metric: tp.Optional[torch.nn.Module] = None

        self.rvm = RelativeVolumeMel()
        self.data_processor = DataProcess(initial_sr=self.sample_rate, target_sr=cfg.resampling.target_sr,
                                          use_resampling=cfg.resampling.use, cutoffs=cfg.filter.cutoffs,
                                          use_filter=cfg.filter.use, n_bands=cfg.filter.n_bands,
                                          idx_band=cfg.filter.idx_band, device=self.device)

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        if self._current_stage == "evaluate":
            return 'rvm'
        else:
            return 'loss'

    @torch.no_grad()
    def get_condition(self, wav: torch.Tensor) -> torch.Tensor:
        codes, scale = self.codec_model.encode(wav)
        assert scale is None, "Scaled compression models not supported."
        emb = self.codec_model.decode_latent(codes)
        return emb
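
    # Note (assumed shapes): the conditioning is the continuous latent decoded
    # from the first `cfg.n_q` RVQ codebooks of the frozen codec, i.e. a
    # [B, emb_dim, T_frames] tensor at the codec frame rate, not discrete codes.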

    def build_model(self):
        """Build model and optimizer as well as optional Exponential Moving Average of the model."""
        self.model = models.builders.get_diffusion_model(self.cfg).to(self.device)
        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
        self.register_stateful('model', 'optimizer')
        self.register_best_state('model')
        self.register_ema('model')

    def build_dataloaders(self):
        """Build audio dataloaders for each stage."""
        self.dataloaders = builders.get_audio_datasets(self.cfg)

    def show(self):
        """Display samples: not implemented for this solver."""
        raise NotImplementedError()

    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
        """Perform one training or valid step on a given batch."""
        x = batch.to(self.device)
        loss_fun = F.mse_loss if self.cfg.loss.kind == 'mse' else F.l1_loss

        condition = self.get_condition(x)
        sample = self.data_processor.process_data(x)

        input_, target, step = self.schedule.get_training_item(sample,
                                                               tensor_step=self.cfg.schedule.variable_step_batch)
        out = self.model(input_, step, condition=condition).sample

        # Normalize the model loss by the loss of the identity prediction
        # (noisy input vs. target), raised to `norm_power`.
        base_loss = loss_fun(out, target, reduction='none').mean(dim=(1, 2))
        reference_loss = loss_fun(input_, target, reduction='none').mean(dim=(1, 2))
        loss = base_loss / reference_loss ** self.cfg.loss.norm_power

        if self.is_training:
            loss.mean().backward()
            flashy.distrib.sync_model(self.model)
            self.optimizer.step()
            self.optimizer.zero_grad()
        metrics = {
            'loss': loss.mean(), 'normed_loss': (base_loss / reference_loss).mean(),
        }
        metrics.update(self.per_stage({'loss': loss, 'normed_loss': base_loss / reference_loss}, step))
        metrics.update({
            'std_in': input_.std(), 'std_out': out.std()})
        return metrics
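
    # With tensor steps and num_stages=4, the returned dict contains 'loss',
    # 'normed_loss', the per-stage entries 'loss_0'..'loss_3' and
    # 'normed_loss_0'..'normed_loss_3', plus 'std_in' and 'std_out'.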

    def run_epoch(self):
        # Fresh RNG seed and per-stage metric buckets at every epoch.
        self.rng = torch.Generator()
        self.rng.manual_seed(1234 + self.epoch)
        self.per_stage = PerStageMetrics(self.schedule.num_steps, self.cfg.metrics.num_stage)

        super().run_epoch()

    def evaluate(self):
        """Evaluate stage.

        Runs audio reconstruction evaluation.
        """
        self.model.eval()
        evaluate_stage_name = f'{self.current_stage}'
        loader = self.dataloaders['evaluate']
        updates = len(loader)
        lp = self.log_progress(f'{evaluate_stage_name} estimate', loader, total=updates, updates=self.log_updates)

        metrics = {}
        n = 1
        for idx, batch in enumerate(lp):
            x = batch.to(self.device)
            with torch.no_grad():
                y_pred = self.regenerate(x)

            y_pred = y_pred.cpu()
            y = batch.cpu()
            rvm = self.rvm(y_pred, y)
            lp.update(**rvm)
            if len(metrics) == 0:
                metrics = rvm
            else:
                # Running mean of the RVM metrics over batches.
                for key in rvm.keys():
                    metrics[key] = (metrics[key] * n + rvm[key]) / (n + 1)
                n += 1
        metrics = flashy.distrib.average_metrics(metrics)
        return metrics

    @torch.no_grad()
    def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] = None):
        """Regenerate the given waveform."""
        condition = self.get_condition(wav)
        initial = self.schedule.get_initial_noise(self.data_processor.process_data(wav))
        result = self.schedule.generate_subsampled(self.model, initial=initial, condition=condition,
                                                   step_list=step_list)
        result = self.data_processor.inverse_process(result)
        return result
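
    # Usage sketch (hypothetical step list): regenerate with a subsampled
    # schedule instead of the full one, e.g.
    #   wav_hat = solver.regenerate(wav, step_list=[0, 250, 500, 750])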

    def generate(self):
        """Generate stage."""
        sample_manager = SampleManager(self.xp)
        self.model.eval()
        generate_stage_name = f'{self.current_stage}'

        loader = self.dataloaders['generate']
        updates = len(loader)
        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

        for batch in lp:
            reference, _ = batch
            reference = reference.to(self.device)
            estimate = self.regenerate(reference)
            reference = reference.cpu()
            estimate = estimate.cpu()
            sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)
        flashy.distrib.barrier()