import pytorch_lightning as pl
import sys, gc
import random
import torch
import torchaudio
import typing as tp
import wandb

from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image
from ema_pytorch import EMA
from einops import rearrange
from safetensors.torch import save_file
from torch import optim
from torch.nn import functional as F

from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper
from ..models.lm import AudioLMContinuousModelWrapper
from .losses import MSELoss, MultiLoss
from .utils import create_optimizer_from_config, create_scheduler_from_config

class AudioLMContinuousModelTrainingWrapper(pl.LightningModule):
    def __init__(
        self,
        model: AudioLMContinuousModelWrapper,
        diffusion: ConditionedDiffusionModelWrapper = None,  # Per-latent diffusion head (assumed to be passed in alongside the LM)
        lr = 1e-4,
        diffusion_objective: tp.Literal["rectified_flow", "v"] = "v",
        timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform",
        use_ema=False,
        ema_copy=None,
        optimizer_configs: dict = None,
        diffusion_batch_mul=4,
        pre_encoded=False
    ):
        super().__init__()

        self.model = model
        self.diffusion = diffusion

        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)

        self.model.pretransform.requires_grad_(False)

        self.timestep_sampler = timestep_sampler
        self.diffusion_objective = model.diffusion_objective

        loss_modules = [
            MSELoss("v", "targets", weight=1.0, name="mse_loss")
        ]

        self.losses = MultiLoss(loss_modules)

        self.model_ema = None
        if use_ema:
            self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10)

        assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"

        if optimizer_configs is None:
            optimizer_configs = {
                "lm": {
                    "optimizer": {
                        "type": "AdamW",
                        "config": {
                            "lr": lr,
                            "betas": (0.9, 0.95),
                            "weight_decay": 0.1
                        }
                    }
                }
            }
        else:
            if lr is not None:
                print("WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.")

        self.optimizer_configs = optimizer_configs

        self.diffusion_batch_mul = diffusion_batch_mul
        self.pre_encoded = pre_encoded

    def configure_optimizers(self):
        lm_opt_config = self.optimizer_configs['lm']
        opt_lm = create_optimizer_from_config(lm_opt_config['optimizer'], self.model.parameters())

        if "scheduler" in lm_opt_config:
            sched_lm = create_scheduler_from_config(lm_opt_config['scheduler'], opt_lm)
            sched_lm_config = {
                "scheduler": sched_lm,
                "interval": "step"
            }
            return [opt_lm], [sched_lm_config]

        return [opt_lm]

    # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/solvers/musicgen.py under MIT license
    # License can be found in LICENSES/LICENSE_META.txt
    def training_step(self, batch, batch_idx):
        reals, metadata = batch

        if reals.ndim == 4 and reals.shape[0] == 1:
            reals = reals[0]

        diffusion_input = reals

        loss_info = {}

        if not self.pre_encoded:
            loss_info["audio_reals"] = diffusion_input

        if self.diffusion.pretransform is not None:
            if not self.pre_encoded:
                with torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
                    diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
            else:
                # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
                if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
                    diffusion_input = diffusion_input / self.diffusion.pretransform.scale

        loss_info["reals"] = diffusion_input

        padding_masks = []
        for md in metadata:
            if md["padding_mask"].ndim == 1:
                padding_masks.append(md["padding_mask"])
            else:
                padding_masks.append(md["padding_mask"][0])

        padding_masks = torch.stack(padding_masks, dim=0).to(self.device)  # Shape (batch_size, sequence_length)

        condition_tensors = None

        # If the model is conditioned, get the conditioning tensors
        if self.model.conditioner is not None:
            with torch.cuda.amp.autocast():
                condition_tensors = self.model.conditioner(metadata, self.device)

        z = self.model.compute_logits(diffusion_input, condition_tensors=condition_tensors, cfg_dropout_prob=0.1)

        bsz, seq_len, _ = z.shape

        # Flatten the latents and LM outputs, repeating them diffusion_batch_mul times
        # so each latent is trained against multiple noise draws per step
        gt_inputs = diffusion_input.clone().detach()
        gt_inputs = gt_inputs.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        z = z.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)

        # Resample the padding masks to the latent sequence length and flatten them to match
        # (note: the mask is not currently applied to the diffusion loss)
        padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=seq_len, mode='nearest').squeeze(1).bool()
        mask = padding_masks.reshape(bsz * seq_len).repeat(self.diffusion_batch_mul)

        if self.timestep_sampler == "uniform":
            # Draw uniformly distributed continuous timesteps
            t = self.rng.draw(z.shape[0])[:, 0].to(self.device)
        elif self.timestep_sampler == "logit_normal":
            t = torch.sigmoid(torch.randn(z.shape[0], device=self.device))

        # Calculate the noise schedule parameters for those timesteps
        if self.diffusion_objective == "v":
            alphas, sigmas = get_alphas_sigmas(t)
        elif self.diffusion_objective == "rectified_flow":
            alphas, sigmas = 1-t, t

        # Combine the ground truth data and the noise
        alphas = alphas[:, None]
        sigmas = sigmas[:, None]
        noise = torch.randn_like(gt_inputs)
        noised_inputs = gt_inputs * alphas + noise * sigmas

        if self.diffusion_objective == "v":
            targets = noise * alphas - gt_inputs * sigmas
        elif self.diffusion_objective == "rectified_flow":
            targets = noise - gt_inputs

        # Condition the diffusion head on the LM outputs
        cond = {}
        cond['z'] = z

        with torch.cuda.amp.autocast():
            v = self.diffusion(noised_inputs, t, cond=cond)

        loss_info.update({
            "v": v,
            "targets": targets
        })

        loss, losses = self.losses(loss_info)

        log_dict = {
            'train/loss': loss.detach(),
            'train/std_data': diffusion_input.std(),
            'train/lr': self.trainer.optimizers[0].param_groups[0]['lr']
        }

        self.log_dict(log_dict, prog_bar=True, on_step=True)

        return loss

    def on_before_zero_grad(self, *args, **kwargs):
        if self.model_ema is not None:
            self.model_ema.update()

    def export_model(self, path, use_safetensors=False):
        model = self.model_ema.ema_model if self.model_ema is not None else self.model

        if use_safetensors:
            save_file(model.state_dict(), path)
        else:
            torch.save({"state_dict": model.state_dict()}, path)

class AudioLanguageModelDemoCallback(pl.Callback):
    def __init__(self,
                 demo_every=2000,
                 num_demos=8,
                 sample_size=65536,
                 sample_rate=48000,
                 demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None,
                 demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7],
                 **kwargs
                 ):
        super().__init__()

        self.demo_every = demo_every
        self.num_demos = num_demos
        self.demo_samples = sample_size
        self.sample_rate = sample_rate
        self.last_demo_step = -1
        self.demo_conditioning = demo_conditioning
        self.demo_cfg_scales = demo_cfg_scales

    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_end(self, trainer, module: AudioLMContinuousModelTrainingWrapper, outputs, batch, batch_idx):
        if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
            return

        module.eval()

        print("Generating demo")
        self.last_demo_step = trainer.global_step

        demo_length_tokens = self.demo_samples // module.model.pretransform.downsampling_ratio

        #demo_reals = batch[0][:self.num_demos]

        # if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
        #     demo_reals = demo_reals[0]

        #demo_reals_tokens = module.model.pretransform.tokenize(demo_reals)

        ## Limit to first 50 tokens
        #demo_reals_tokens = demo_reals_tokens[:, :, :50]

        try:
            print("Getting conditioning")

            for cfg_scale in self.demo_cfg_scales:

                model = module.model # module.model_ema.ema_model if module.model_ema is not None else module.model

                print(f"Generating demo for cfg scale {cfg_scale}")
                fakes = model.generate_audio(
                    batch_size=self.num_demos,
                    max_gen_len=demo_length_tokens,
                    conditioning=self.demo_conditioning,
                    #init_data = demo_reals_tokens,
                    cfg_scale=cfg_scale,
                    temp=1.0,
                    top_p=0.95
                )

                # Put the demos together
                fakes = rearrange(fakes, 'b d n -> d (b n)')

                log_dict = {}

                filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav'
                fakes = fakes / fakes.abs().max()
                fakes = fakes.type(torch.float32).clamp(-1, 1).mul(32767).type(torch.int16).cpu()
                torchaudio.save(filename, fakes, self.sample_rate)

                log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename,
                                                                sample_rate=self.sample_rate,
                                                                caption='Reconstructed')

                log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes))

                trainer.logger.experiment.log(log_dict)

        except Exception as e:
            raise e
        finally:
            gc.collect()
            torch.cuda.empty_cache()
            module.train()
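
# Example wiring (sketch only): the objects `model`, `diffusion_head`, `conditioning`, and
# `train_dl` below are hypothetical and depend on how the AudioLMContinuousModelWrapper and
# its diffusion head are constructed elsewhere in this repo.
#
#   import copy
#
#   training_wrapper = AudioLMContinuousModelTrainingWrapper(
#       model=model,                      # AudioLMContinuousModelWrapper instance
#       diffusion=diffusion_head,         # per-latent diffusion head (assumed separate module)
#       lr=1e-4,
#       use_ema=True,
#       ema_copy=copy.deepcopy(model),
#   )
#
#   demo_callback = AudioLanguageModelDemoCallback(
#       demo_every=2000,
#       sample_size=65536,
#       sample_rate=48000,
#       demo_conditioning=conditioning,   # must match the model's conditioner inputs
#       demo_cfg_scales=[3, 5, 7],
#   )
#
#   trainer = pl.Trainer(accelerator="gpu", devices=1, precision=16, callbacks=[demo_callback])
#   trainer.fit(training_wrapper, train_dl)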