# Spaces: Sleeping — HuggingFace Spaces status banner captured in the page
# scrape; not part of the source code.
import torch
import numpy as np
import torchmetrics
import pytorch_lightning as pl
from torch import Tensor, nn
from torchaudio.models import HDemucs
from audio_diffusion_pytorch import DiffusionModel
from auraloss.time import SISDRLoss
from auraloss.freq import MultiResolutionSTFTLoss
from umx.openunmix.model import OpenUnmix, Separator
from remfx.utils import FADLoss, spectrogram
from remfx.tcn import TCN
from remfx.utils import causal_crop
from remfx.callbacks import log_wandb_audio_batch
from einops import rearrange
from remfx import effects
import asteroid

# Registry of every pedalboard effect class. RemFXChainInference.forward
# indexes it by label position, so its order must match the multi-hot
# effect labels produced by the dataset.
ALL_EFFECTS = effects.Pedalboard_Effects
class RemFXChainInference(pl.LightningModule):
    """Inference-only chain of per-effect removal models.

    Given a batch whose multi-hot labels mark which effects are present,
    each sample is passed through the matching single-effect models in
    ``effect_order`` (or an explicit ``order``), and input/target/output
    audio is logged to wandb.
    """

    def __init__(self, models, sample_rate, num_bins, effect_order):
        # models: mapping effect-name -> wrapper exposing .model.sample()
        #   (indexed by name in forward) — TODO confirm exact container type.
        # effect_order: list of effect names giving the removal order.
        super().__init__()
        self.model = models
        self.mrstftloss = MultiResolutionSTFTLoss(
            n_bins=num_bins, sample_rate=sample_rate
        )
        self.l1loss = nn.L1Loss()
        # Reference metrics; SISDR is negated at logging time in test_step.
        self.metrics = nn.ModuleDict(
            {
                "SISDR": SISDRLoss(),
                "STFT": MultiResolutionSTFTLoss(),
            }
        )
        self.sample_rate = sample_rate
        self.effect_order = effect_order

    def forward(self, batch, batch_idx, order=None):
        """Run the effect-removal chain; returns ``(loss, output)``."""
        x, y, _, rem_fx_labels = batch
        # Use chain of effects defined in config
        if order:
            effects_order = order
        else:
            effects_order = self.effect_order
        # Multi-hot labels -> list of effect classes present per example.
        effects_present = [
            [ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
            for effect_label in rem_fx_labels
        ]
        output = []
        # Concatenate the whole batch along time so one audio clip is logged.
        input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
        target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
        log_wandb_audio_batch(
            logger=self.logger,
            id="input_effected_audio",
            samples=input_samples.cpu(),
            sampling_rate=self.sample_rate,
            caption="Input Data",
        )
        log_wandb_audio_batch(
            logger=self.logger,
            id="target_audio",
            samples=target_samples.cpu(),
            sampling_rate=self.sample_rate,
            caption="Target Data",
        )
        with torch.no_grad():
            # Process one example at a time: each may need a different chain.
            for i, (elem, effects_list) in enumerate(zip(x, effects_present)):
                elem = elem.unsqueeze(0)  # Add batch dim
                # Get the correct effect by search for names in effects_order
                effect_list_names = [effect.__name__ for effect in effects_list]
                # NOTE(review): this local shadows the imported
                # ``remfx.effects`` module for the rest of the loop body.
                effects = [
                    effect for effect in effects_order if effect in effect_list_names
                ]
                # log_wandb_audio_batch(
                #     logger=self.logger,
                #     id=f"{i}_Before",
                #     samples=elem.cpu(),
                #     sampling_rate=self.sample_rate,
                #     caption=effects,
                # )
                for effect in effects:
                    # Sample the model
                    elem = self.model[effect].model.sample(elem)
                    # log_wandb_audio_batch(
                    #     logger=self.logger,
                    #     id=f"{i}_{effect}",
                    #     samples=elem.cpu(),
                    #     sampling_rate=self.sample_rate,
                    #     caption=effects,
                    # )
                # log_wandb_audio_batch(
                #     logger=self.logger,
                #     id=f"{i}_After",
                #     samples=elem.cpu(),
                #     sampling_rate=self.sample_rate,
                #     caption=effects,
                # )
                output.append(elem.squeeze(0))
        output = torch.stack(output)
        output_samples = rearrange(output, "b c t -> c (b t)").unsqueeze(0)
        log_wandb_audio_batch(
            logger=self.logger,
            id="output_audio",
            samples=output_samples.cpu(),
            sampling_rate=self.sample_rate,
            caption="Output Data",
        )
        # Same objective as the individual models: MR-STFT + 100 * L1.
        loss = self.mrstftloss(output, y) + self.l1loss(output, y) * 100
        return loss, output

    def test_step(self, batch, batch_idx):
        x, y, _, _ = batch  # x, y = (B, C, T), (B, C, T)
        # Random order
        # random.shuffle(self.effect_order)
        loss, output = self.forward(batch, batch_idx, order=self.effect_order)
        # Crop target to match output
        if output.shape[-1] < y.shape[-1]:
            y = causal_crop(y, output.shape[-1])
        self.log("test_loss", loss)
        # Metric logging
        with torch.no_grad():
            for metric in self.metrics:
                # SISDR returns negative values, so negate them
                if metric == "SISDR":
                    negate = -1
                else:
                    negate = 1
                self.log(
                    f"test_{metric}",  # + "".join(self.effect_order).replace("RandomPedalboard", ""),
                    negate * self.metrics[metric](output, y),
                    on_step=False,
                    on_epoch=True,
                    logger=True,
                    prog_bar=True,
                    sync_dist=True,
                )
                # NOTE(review): x is not cropped here; if output was shorter
                # than y this comparison may have mismatched lengths — verify.
                self.log(
                    f"Input_{metric}",
                    negate * self.metrics[metric](x, y),
                    on_step=False,
                    on_epoch=True,
                    logger=True,
                    prog_bar=True,
                    sync_dist=True,
                )
        return loss

    def sample(self, batch):
        # Convenience wrapper: run the chain, return only the audio output.
        return self.forward(batch, 0)[1]
class RemFX(pl.LightningModule):
    """Lightning wrapper that trains a single effect-removal network.

    ``network`` is expected to accept an ``(x, y)`` tuple and return
    ``(loss, output)``; this module handles optimization, LR scheduling,
    and metric logging.
    """

    def __init__(
        self,
        lr: float,
        lr_beta1: float,
        lr_beta2: float,
        lr_eps: float,
        lr_weight_decay: float,
        sample_rate: float,
        network: nn.Module,
    ):
        super().__init__()
        self.lr = lr
        self.lr_beta1 = lr_beta1
        self.lr_beta2 = lr_beta2
        self.lr_eps = lr_eps
        self.lr_weight_decay = lr_weight_decay
        self.sample_rate = sample_rate
        self.model = network
        # Reference metrics; "SISDR" is a loss and is negated when logged.
        self.metrics = nn.ModuleDict(
            {"SISDR": SISDRLoss(), "STFT": MultiResolutionSTFTLoss()}
        )
        # Log first batch metrics input vs output only once
        self.log_train_audio = True

    def device(self):
        # NOTE(review): plain method, not a property; callers must invoke it.
        return next(self.model.parameters()).device

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            list(self.model.parameters()),
            lr=self.lr,
            betas=(self.lr_beta1, self.lr_beta2),
            eps=self.lr_eps,
            weight_decay=self.lr_weight_decay,
        )
        # Decay the LR by 10x at 80% and 95% of training.
        milestones = [0.8 * self.trainer.max_steps, 0.95 * self.trainer.max_steps]
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones, gamma=0.1
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "monitor": "val_loss",
                "interval": "step",
                "frequency": 1,
            },
        }

    def training_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, mode="train")

    def validation_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, mode="valid")

    def test_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, mode="test")

    def common_step(self, batch, batch_idx, mode: str = "train"):
        """Shared step: forward the network, log loss and reference metrics."""
        x, y, _, _ = batch  # x, y = (B, C, T), (B, C, T)
        loss, output = self.model((x, y))
        # The network may return fewer samples than the target; crop to match.
        target = y
        if output.shape[-1] < y.shape[-1]:
            target = causal_crop(y, output.shape[-1])
        self.log(f"{mode}_loss", loss)
        log_kwargs = dict(
            on_step=False,
            on_epoch=True,
            logger=True,
            prog_bar=True,
            sync_dist=True,
        )
        with torch.no_grad():
            for name, fn in self.metrics.items():
                # SISDR is a (negative) loss; flip its sign for logging.
                sign = -1 if name == "SISDR" else 1
                # Only Log FAD on test set
                if name == "FAD" and mode != "test":
                    continue
                self.log(f"{mode}_{name}", sign * fn(output, target), **log_kwargs)
                self.log(f"Input_{name}", sign * fn(x, y), **log_kwargs)
        return loss
class OpenUnmixModel(nn.Module):
    """Open-Unmix separation wrapper used for effect removal.

    The trainable network is one OpenUnmix spectrogram model; a Separator
    wraps the same model and handles STFT/iSTFT so the loss (MR-STFT +
    100 * L1) is computed on waveforms.

    Args:
        n_fft: FFT size for the separator.
        hop_length: STFT hop size.
        n_channels: Number of audio channels.
        alpha: Spectrogram compression factor for the ``spectrogram`` helper
            (kept for interface compatibility).
        sample_rate: Audio sample rate.
    """

    def __init__(
        self,
        n_fft: int = 2048,
        hop_length: int = 512,
        n_channels: int = 1,
        alpha: float = 0.3,
        sample_rate: int = 22050,
    ):
        super().__init__()
        self.n_channels = n_channels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.alpha = alpha
        window = torch.hann_window(n_fft)
        self.register_buffer("window", window)
        self.num_bins = self.n_fft // 2 + 1
        self.sample_rate = sample_rate
        self.model = OpenUnmix(
            nb_channels=self.n_channels,
            nb_bins=self.num_bins,
        )
        # The separator wraps the same model instance, so training through
        # the separator updates self.model's weights.
        self.separator = Separator(
            target_models={"other": self.model},
            nb_channels=self.n_channels,
            sample_rate=self.sample_rate,
            n_fft=self.n_fft,
            n_hop=self.hop_length,
        )
        self.mrstftloss = MultiResolutionSTFTLoss(
            n_bins=self.num_bins, sample_rate=self.sample_rate
        )
        self.l1loss = nn.L1Loss()

    def forward(self, batch):
        """Return ``(loss, separated)`` for an ``(input, target)`` pair.

        Fix: the original additionally ran ``self.model`` on a spectrogram
        of ``x`` and discarded the result — a wasted forward pass with no
        effect on the loss or output, since the separator already invokes
        the same model. That dead computation is removed.
        """
        x, target = batch
        sep_out = self.separator(x).squeeze(1)
        loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target) * 100
        return loss, sep_out

    def sample(self, x: Tensor) -> Tensor:
        """Inference: separated waveform with the source dim squeezed out."""
        return self.separator(x).squeeze(1)
class DemucsModel(nn.Module):
    """HDemucs wrapper trained with MR-STFT + 100 * L1 waveform loss."""

    def __init__(self, sample_rate, **kwargs) -> None:
        super().__init__()
        self.model = HDemucs(**kwargs)
        # kwargs must carry "nfft"; bin count follows the usual n_fft//2 + 1.
        self.num_bins = kwargs["nfft"] // 2 + 1
        self.mrstftloss = MultiResolutionSTFTLoss(
            n_bins=self.num_bins, sample_rate=sample_rate
        )
        self.l1loss = nn.L1Loss()

    def forward(self, batch):
        """Return ``(loss, estimate)`` for an ``(input, target)`` pair."""
        x, target = batch
        estimate = self.model(x).squeeze(1)
        spectral = self.mrstftloss(estimate, target)
        time_domain = self.l1loss(estimate, target)
        return spectral + time_domain * 100, estimate

    def sample(self, x: Tensor) -> Tensor:
        """Inference: the raw model estimate with the source dim squeezed."""
        return self.model(x).squeeze(1)
class DiffusionGenerationModel(nn.Module):
    """Thin wrapper around audio-diffusion-pytorch's ``DiffusionModel``."""

    def __init__(self, n_channels: int = 1):
        super().__init__()
        self.model = DiffusionModel(in_channels=n_channels)

    def forward(self, batch):
        """Return ``(diffusion loss on x, a sample seeded from x)``."""
        x, target = batch  # NOTE(review): target unused here — confirm intended.
        generated = self.model.sample(x)
        return self.model(x), generated

    def sample(self, x: Tensor, num_steps: int = 10) -> Tensor:
        """Generate audio shaped like ``x`` starting from Gaussian noise."""
        seed = torch.randn(x.shape).to(x)
        return self.model.sample(seed, num_steps=num_steps)
class DPTNetModel(nn.Module):
    """DPTNet wrapper; squeezes the channel dim before the network."""

    def __init__(self, sample_rate, num_bins, **kwargs):
        super().__init__()
        self.model = asteroid.models.dptnet.DPTNet(**kwargs)
        self.num_bins = num_bins
        self.mrstftloss = MultiResolutionSTFTLoss(
            n_bins=self.num_bins, sample_rate=sample_rate
        )
        self.l1loss = nn.L1Loss()

    def forward(self, batch):
        """Return ``(loss, estimate)``; loss is MR-STFT + 100 * L1."""
        x, target = batch
        estimate = self.model(x.squeeze(1))
        combined = (
            self.mrstftloss(estimate, target) + self.l1loss(estimate, target) * 100
        )
        return combined, estimate

    def sample(self, x: Tensor) -> Tensor:
        """Inference: the raw network estimate."""
        return self.model(x.squeeze(1))
class DCUNetModel(nn.Module):
    """DCUNet wrapper; the network may shorten the signal, so the target is
    causally cropped before the MR-STFT + 100 * L1 loss."""

    def __init__(self, sample_rate, num_bins, **kwargs):
        super().__init__()
        self.model = asteroid.models.DCUNet(**kwargs)
        self.mrstftloss = MultiResolutionSTFTLoss(
            n_bins=num_bins, sample_rate=sample_rate
        )
        self.l1loss = nn.L1Loss()

    def forward(self, batch):
        """Return ``(loss, estimate)`` for an ``(input, target)`` pair."""
        x, target = batch
        estimate = self.model(x.squeeze(1))  # B x T
        if estimate.shape[-1] < target.shape[-1]:
            # Align the target with the (possibly shorter) estimate.
            target = causal_crop(target, estimate.shape[-1])
        spectral = self.mrstftloss(estimate, target)
        time_domain = self.l1loss(estimate, target)
        return spectral + time_domain * 100, estimate

    def sample(self, x: Tensor) -> Tensor:
        """Inference: the raw network estimate (B x T)."""
        return self.model(x.squeeze(1))
class TCNModel(nn.Module):
    """TCN wrapper; the network may shorten the signal, so the target is
    causally cropped before the MR-STFT + 100 * L1 loss."""

    def __init__(self, sample_rate, num_bins, **kwargs):
        super().__init__()
        self.model = TCN(**kwargs)
        self.mrstftloss = MultiResolutionSTFTLoss(
            n_bins=num_bins, sample_rate=sample_rate
        )
        self.l1loss = nn.L1Loss()

    def forward(self, batch):
        """Return ``(loss, estimate)`` for an ``(input, target)`` pair."""
        x, target = batch
        estimate = self.model(x)  # B x 1 x T
        if estimate.shape[-1] < target.shape[-1]:
            # Align the target with the (possibly shorter) estimate.
            target = causal_crop(target, estimate.shape[-1])
        spectral = self.mrstftloss(estimate, target)
        time_domain = self.l1loss(estimate, target)
        return spectral + time_domain * 100, estimate

    def sample(self, x: Tensor) -> Tensor:
        """Inference: the raw network estimate (B x 1 x T)."""
        return self.model(x)
def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
    """Mixup data augmentation for time-domain signals.

    Draws lam ~ Beta(alpha, alpha) (lam = 1 when alpha <= 0, i.e. no
    mixing) and blends each example with a randomly permuted partner.

    Args:
        x (torch.Tensor): Batch of time-domain signals, shape [batch, 1, time].
        y (torch.Tensor): Batch of labels, shape [batch, n_classes].
        alpha (float): Beta distribution parameter.

    Returns:
        torch.Tensor: Mixed time-domain signals, shape [batch, 1, time].
        torch.Tensor: Mixed labels, shape [batch, n_classes].
        torch.Tensor: Lambda
    """
    # Keep RNG order identical to the original: numpy beta draw first,
    # then the torch permutation.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    perm = torch.randperm(x.size(0)).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[perm, :]
    mixed_y = lam * y + (1 - lam) * y[perm, :]
    return mixed_x, mixed_y, lam
class FXClassifier(pl.LightningModule):
    """Multi-label classifier predicting which audio effects are present.

    Wraps ``network`` (called as ``network(x, train=...)``, returning one
    logit per effect) and trains it with cross-entropy against multi-hot
    effect labels, optionally applying mixup augmentation on the waveforms.

    Args:
        lr: AdamW learning rate.
        lr_weight_decay: AdamW weight decay.
        sample_rate: Audio sample rate (stored; unused in this class).
        network: Backbone producing per-effect logits.
        mixup: If True, apply mixup to training batches.
    """

    def __init__(
        self,
        lr: float,
        lr_weight_decay: float,
        sample_rate: float,
        network: nn.Module,
        mixup: bool = False,
    ):
        super().__init__()
        self.lr = lr
        self.lr_weight_decay = lr_weight_decay
        self.sample_rate = sample_rate
        self.network = network
        # Label order: index i of a prediction corresponds to effects[i].
        self.effects = ["Reverb", "Chorus", "Delay", "Distortion", "Compressor"]
        self.mixup = mixup
        self.train_f1 = torchmetrics.classification.MultilabelF1Score(
            5, average="none", multidim_average="global"
        )
        self.val_f1 = torchmetrics.classification.MultilabelF1Score(
            5, average="none", multidim_average="global"
        )
        self.test_f1 = torchmetrics.classification.MultilabelF1Score(
            5, average="none", multidim_average="global"
        )
        # Plain dict keyed by mode; the metric modules are registered as
        # attributes above, so they still follow the module's device.
        self.metrics = {
            "train": self.train_f1,
            "valid": self.val_f1,
            "test": self.test_f1,
        }

    def forward(self, x: torch.Tensor, train: bool = False):
        return self.network(x, train=train)

    def common_step(self, batch, batch_idx, mode: str = "train"):
        """Shared train/valid/test step: compute loss and per-effect F1."""
        train = mode == "train"  # idiom fix: was `True if ... else False`
        x, y, dry_label, wet_label = batch
        if mode == "train" and self.mixup:
            x_mixed, label_mixed, lam = mixup(x, wet_label)
            pred_label = self(x_mixed, train)
            # NOTE(review): cross_entropy treats multi-hot targets as class
            # probabilities; BCE-with-logits may be intended — confirm.
            loss = nn.functional.cross_entropy(pred_label, label_mixed)
        else:
            pred_label = self(x, train)
            loss = nn.functional.cross_entropy(pred_label, wet_label)
        # Fix: removed leftover per-step debug print() of predictions/labels.
        self.log(
            f"{mode}_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        metrics = self.metrics[mode](torch.sigmoid(pred_label), wet_label.long())
        avg_metrics = torch.mean(metrics)
        self.log(
            f"{mode}_f1_avg",
            avg_metrics,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        # Per-effect F1 so regressions on individual effects stay visible.
        for idx, effect_name in enumerate(self.effects):
            self.log(
                f"{mode}_f1_{effect_name}",
                metrics[idx],
                on_step=True,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )
        return loss

    def training_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, mode="train")

    def validation_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, mode="valid")

    def test_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, mode="test")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.network.parameters(),
            lr=self.lr,
            weight_decay=self.lr_weight_decay,
        )
        return optimizer