import torch
import math
import numpy as np

from torch import nn, sin, pow
from torch.nn import functional as F
from torch.nn import Parameter
from torchaudio import transforms as T

from alias_free_torch import Activation1d
from dac.nn.layers import WNConv1d, WNConvTranspose1d
from typing import List, Literal, Dict, Any, Callable
from einops import rearrange

from ...inference.sampling import sample
from ...inference.utils import prepare_audio
from .bottleneck import Bottleneck
from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
from .factory import create_pretransform_from_config, create_bottleneck_from_config
from .pretransforms import Pretransform, AutoencoderPretransform

def snake_beta(x, alpha, beta):
    return x + (1.0 / (beta + 1e-9)) * pow(sin(x * alpha), 2)

try:
    snake_beta = torch.compile(snake_beta)
except RuntimeError:
    pass

# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
# License available in LICENSES/LICENSE_NVIDIA.txt
class SnakeBeta(nn.Module):

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:
            # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 1e-9

    def forward(self, x):
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = snake_beta(x, alpha, beta)

        return x

def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
    if activation == "elu":
        act = nn.ELU()
    elif activation == "snake":
        act = SnakeBeta(channels)
    elif activation == "none":
        act = nn.Identity()
    else:
        raise ValueError(f"Unknown activation {activation}")

    if antialias:
        act = Activation1d(act)

    return act

class ResidualUnit(nn.Module):
    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        super().__init__()

        self.dilation = dilation

        act = get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels)

        padding = (dilation * (7 - 1)) // 2

        self.layers = nn.Sequential(
            act,
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=7, dilation=dilation, padding=padding),
            act,
            WNConv1d(in_channels=out_channels, out_channels=out_channels,
                     kernel_size=1)
        )

    def forward(self, x):
        return x + self.layers(x)

class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        super().__init__()

        act = get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels)

        self.layers = nn.Sequential(
            ResidualUnit(in_channels=in_channels, out_channels=in_channels, dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels, out_channels=in_channels, dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels, out_channels=in_channels, dilation=9, use_snake=use_snake),
            act,
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)),
        )

    def forward(self, x):
        return self.layers(x)
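# Usage sketch (illustrative only, not part of the library API): SnakeBeta and
# EncoderBlock both operate on [batch, channels, time] tensors. An EncoderBlock
# with stride s shortens the time axis by a factor of s while remapping channels:
#
#   x = torch.randn(1, 64, 2048)                                          # [B, C, T]
#   block = EncoderBlock(in_channels=64, out_channels=128, stride=4, use_snake=True)
#   y = block(x)                                                          # -> [1, 128, 512]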
class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()

        if use_nearest_upsample:
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(in_channels=in_channels, out_channels=out_channels,
                         kernel_size=2 * stride, stride=1, bias=False, padding='same')
            )
        else:
            upsample_layer = WNConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
                                               kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))

        act = get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels)

        self.layers = nn.Sequential(
            act,
            upsample_layer,
            ResidualUnit(in_channels=out_channels, out_channels=out_channels, dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels, dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels, dilation=9, use_snake=use_snake),
        )

    def forward(self, x):
        return self.layers(x)

class OobleckEncoder(nn.Module):
    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults=[1, 2, 4, 8],
                 strides=[2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False):
        super().__init__()

        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
        ]

        for i in range(self.depth - 1):
            layers += [EncoderBlock(in_channels=c_mults[i] * channels, out_channels=c_mults[i + 1] * channels,
                                    stride=strides[i], use_snake=use_snake)]

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
            WNConv1d(in_channels=c_mults[-1] * channels, out_channels=latent_dim, kernel_size=3, padding=1)
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

class OobleckDecoder(nn.Module):
    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults=[1, 2, 4, 8],
                 strides=[2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()

        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1] * channels, kernel_size=7, padding=3),
        ]

        for i in range(self.depth - 1, 0, -1):
            layers += [DecoderBlock(
                in_channels=c_mults[i] * channels,
                out_channels=c_mults[i - 1] * channels,
                stride=strides[i - 1],
                use_snake=use_snake,
                antialias_activation=antialias_activation,
                use_nearest_upsample=use_nearest_upsample
            )]

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
            WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
            nn.Tanh() if final_tanh else nn.Identity()
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
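# Shape sketch (illustrative): with the default strides [2, 4, 8, 8] the encoder
# downsamples the time axis by prod(strides) = 512x, and the decoder mirrors it back:
#
#   enc = OobleckEncoder(in_channels=2, latent_dim=32)
#   dec = OobleckDecoder(out_channels=2, latent_dim=32)
#   audio = torch.randn(1, 2, 512 * 100)      # [B, C, T], T a multiple of 512
#   z = enc(audio)                            # -> [1, 32, 100]
#   recon = dec(z)                            # -> [1, 2, 51200]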
class DACEncoderWrapper(nn.Module):
    def __init__(self, in_channels=1, **kwargs):
        super().__init__()

        from dac.model.dac import Encoder as DACEncoder

        latent_dim = kwargs.pop("latent_dim", None)

        encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
        self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
        self.latent_dim = latent_dim

        # Latent-dim support was added to DAC after this was first written, and
        # implemented differently, so this is for backwards compatibility
        self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()

        if in_channels != 1:
            self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)

    def forward(self, x):
        x = self.encoder(x)
        x = self.proj_out(x)
        return x

class DACDecoderWrapper(nn.Module):
    def __init__(self, latent_dim, out_channels=1, **kwargs):
        super().__init__()

        from dac.model.dac import Decoder as DACDecoder

        self.decoder = DACDecoder(**kwargs, input_channel=latent_dim, d_out=out_channels)

        self.latent_dim = latent_dim

    def forward(self, x):
        return self.decoder(x)

class AudioAutoencoder(nn.Module):
    def __init__(
        self,
        encoder,
        decoder,
        latent_dim,
        downsampling_ratio,
        sample_rate,
        io_channels=2,
        bottleneck: Bottleneck = None,
        pretransform: Pretransform = None,
        in_channels=None,
        out_channels=None,
        soft_clip=False
    ):
        super().__init__()

        self.downsampling_ratio = downsampling_ratio
        self.sample_rate = sample_rate

        self.latent_dim = latent_dim
        self.io_channels = io_channels
        self.in_channels = io_channels
        self.out_channels = io_channels

        self.min_length = self.downsampling_ratio

        if in_channels is not None:
            self.in_channels = in_channels

        if out_channels is not None:
            self.out_channels = out_channels

        self.bottleneck = bottleneck
        self.encoder = encoder
        self.decoder = decoder
        self.pretransform = pretransform
        self.soft_clip = soft_clip

    def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):

        info = {}

        if self.pretransform is not None and not skip_pretransform:
            if self.pretransform.enable_grad:
                if iterate_batch:
                    audios = []
                    for i in range(audio.shape[0]):
                        audios.append(self.pretransform.encode(audio[i:i+1]))
                    audio = torch.cat(audios, dim=0)
                else:
                    audio = self.pretransform.encode(audio)
            else:
                with torch.no_grad():
                    if iterate_batch:
                        audios = []
                        for i in range(audio.shape[0]):
                            audios.append(self.pretransform.encode(audio[i:i+1]))
                        audio = torch.cat(audios, dim=0)
                    else:
                        audio = self.pretransform.encode(audio)

        if self.encoder is not None:
            if iterate_batch:
                latents = []
                for i in range(audio.shape[0]):
                    latents.append(self.encoder(audio[i:i+1]))
                latents = torch.cat(latents, dim=0)
            else:
                latents = self.encoder(audio)
        else:
            latents = audio

        if self.bottleneck is not None:
            # TODO: Add iterate batch logic, needs to merge the info dicts
            latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)

            info.update(bottleneck_info)

        if return_info:
            return latents, info

        return latents

    def decode(self, latents, iterate_batch=False, **kwargs):

        if self.bottleneck is not None:
            if iterate_batch:
                decoded = []
                for i in range(latents.shape[0]):
                    decoded.append(self.bottleneck.decode(latents[i:i+1]))
                # Assign back to latents so the decoder below sees the decoded bottleneck output
                latents = torch.cat(decoded, dim=0)
            else:
                latents = self.bottleneck.decode(latents)

        if iterate_batch:
            decoded = []
            for i in range(latents.shape[0]):
                decoded.append(self.decoder(latents[i:i+1]))
            decoded = torch.cat(decoded, dim=0)
        else:
            decoded = self.decoder(latents, **kwargs)

        if self.pretransform is not None:
            if self.pretransform.enable_grad:
                if iterate_batch:
                    decodeds = []
                    for i in range(decoded.shape[0]):
                        decodeds.append(self.pretransform.decode(decoded[i:i+1]))
                    decoded = torch.cat(decodeds, dim=0)
                else:
                    decoded = self.pretransform.decode(decoded)
            else:
                with torch.no_grad():
                    if iterate_batch:
                        decodeds = []
                        # Iterate over the decoder output, not the (shorter) latents
                        for i in range(decoded.shape[0]):
                            decodeds.append(self.pretransform.decode(decoded[i:i+1]))
                        decoded = torch.cat(decodeds, dim=0)
                    else:
                        decoded = self.pretransform.decode(decoded)

        if self.soft_clip:
            decoded = torch.tanh(decoded)

        return decoded

    def encode_audio(self, audio, in_sr, **kwargs):
        '''
        Encode a single audio tensor to latents, including preprocessing the audio to be compatible with the model
        '''
        if in_sr != self.sample_rate:
            resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
            audio = resample_tf(audio)

        audio_length = audio.shape[-1]

        pad_length = (self.min_length - (audio_length % self.min_length)) % self.min_length

        # Pad with zeros to a multiple of the model's downsampling ratio
        audio = F.pad(audio, (0, pad_length))

        audio = prepare_audio(audio, in_sr=self.sample_rate, target_sr=self.sample_rate,
                              target_length=audio.shape[-1], target_channels=self.in_channels, device=audio.device)

        # TODO: Add chunking logic

        return self.encode(audio, **kwargs)

    def decode_audio(self, latents, **kwargs):
        '''
        Decode latents to audio
        '''
        # TODO: Add chunking logic

        return self.decode(latents, **kwargs)
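# Round-trip sketch (illustrative; assumes an Oobleck encoder/decoder pair and no
# bottleneck or pretransform). encode_audio pads the input up to a multiple of
# min_length, so 44100 samples become 44544 = 87 * 512 before encoding:
#
#   ae = AudioAutoencoder(
#       encoder=OobleckEncoder(in_channels=2, latent_dim=64),
#       decoder=OobleckDecoder(out_channels=2, latent_dim=64),
#       latent_dim=64, downsampling_ratio=512, sample_rate=44100, io_channels=2,
#   )
#   audio = torch.randn(1, 2, 44100)                 # one second of stereo noise
#   latents = ae.encode_audio(audio, in_sr=44100)    # -> [1, 64, 87]
#   recon = ae.decode_audio(latents)                 # -> [1, 2, 44544]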
class DiffusionAutoencoder(AudioAutoencoder):
    def __init__(
        self,
        diffusion: ConditionedDiffusionModel,
        diffusion_downsampling_ratio,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)

        self.diffusion = diffusion

        self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio

        if self.encoder is not None:
            # Shrink the initial encoder parameters to avoid saturated latents
            with torch.no_grad():
                for param in self.encoder.parameters():
                    param *= 0.5

    def decode(self, latents, steps=100):

        upsampled_length = latents.shape[2] * self.downsampling_ratio

        if self.bottleneck is not None:
            latents = self.bottleneck.decode(latents)

        if self.decoder is not None:
            latents = self.decoder(latents)

        # Upsample latents to match diffusion length
        if latents.shape[2] != upsampled_length:
            latents = F.interpolate(latents, size=upsampled_length, mode='nearest')

        noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)
        decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)

        if self.pretransform is not None:
            if self.pretransform.enable_grad:
                decoded = self.pretransform.decode(decoded)
            else:
                with torch.no_grad():
                    decoded = self.pretransform.decode(decoded)

        return decoded

# AE factories

def create_encoder_from_config(encoder_config: Dict[str, Any]):
    encoder_type = encoder_config.get("type", None)
    assert encoder_type is not None, "Encoder type must be specified"

    if encoder_type == "oobleck":
        encoder = OobleckEncoder(
            **encoder_config["config"]
        )
    elif encoder_type == "seanet":
        from encodec.modules import SEANetEncoder
        seanet_encoder_config = encoder_config["config"]

        # SEANet encoder expects strides in reverse order
        seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
        encoder = SEANetEncoder(
            **seanet_encoder_config
        )
    elif encoder_type == "dac":
        dac_config = encoder_config["config"]

        encoder = DACEncoderWrapper(**dac_config)
    elif encoder_type == "local_attn":
        from .local_attention import TransformerEncoder1D

        local_attn_config = encoder_config["config"]

        encoder = TransformerEncoder1D(
            **local_attn_config
        )
    else:
        raise ValueError(f"Unknown encoder type {encoder_type}")

    requires_grad = encoder_config.get("requires_grad", True)
    if not requires_grad:
        for param in encoder.parameters():
            param.requires_grad = False

    return encoder
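# Config sketch (illustrative values, not a shipped preset): each factory reads a
# "type" tag plus a "config" dict of constructor kwargs, with an optional
# "requires_grad" flag to freeze the module:
#
#   encoder = create_encoder_from_config({
#       "type": "oobleck",
#       "config": {"in_channels": 2, "channels": 128, "latent_dim": 64,
#                  "c_mults": [1, 2, 4, 8], "strides": [2, 4, 8, 8]},
#       "requires_grad": True,
#   })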
def create_decoder_from_config(decoder_config: Dict[str, Any]):
    decoder_type = decoder_config.get("type", None)
    assert decoder_type is not None, "Decoder type must be specified"

    if decoder_type == "oobleck":
        decoder = OobleckDecoder(
            **decoder_config["config"]
        )
    elif decoder_type == "seanet":
        from encodec.modules import SEANetDecoder

        decoder = SEANetDecoder(
            **decoder_config["config"]
        )
    elif decoder_type == "dac":
        dac_config = decoder_config["config"]

        decoder = DACDecoderWrapper(**dac_config)
    elif decoder_type == "local_attn":
        from .local_attention import TransformerDecoder1D

        local_attn_config = decoder_config["config"]

        decoder = TransformerDecoder1D(
            **local_attn_config
        )
    else:
        raise ValueError(f"Unknown decoder type {decoder_type}")

    requires_grad = decoder_config.get("requires_grad", True)
    if not requires_grad:
        for param in decoder.parameters():
            param.requires_grad = False

    return decoder

def create_autoencoder_from_config(config: Dict[str, Any]):

    ae_config = config["model"]

    encoder = create_encoder_from_config(ae_config["encoder"])
    decoder = create_decoder_from_config(ae_config["decoder"])

    bottleneck = ae_config.get("bottleneck", None)

    latent_dim = ae_config.get("latent_dim", None)
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = ae_config.get("downsampling_ratio", None)
    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
    io_channels = ae_config.get("io_channels", None)
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    in_channels = ae_config.get("in_channels", None)
    out_channels = ae_config.get("out_channels", None)

    pretransform = ae_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)

    if bottleneck is not None:
        bottleneck = create_bottleneck_from_config(bottleneck)

    soft_clip = ae_config["decoder"].get("soft_clip", False)

    return AudioAutoencoder(
        encoder,
        decoder,
        io_channels=io_channels,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        sample_rate=sample_rate,
        bottleneck=bottleneck,
        pretransform=pretransform,
        in_channels=in_channels,
        out_channels=out_channels,
        soft_clip=soft_clip
    )
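# Top-level config sketch (illustrative): create_autoencoder_from_config expects the
# model description nested under "model", with sample_rate at the root:
#
#   ae = create_autoencoder_from_config({
#       "sample_rate": 44100,
#       "model": {
#           "encoder": {"type": "oobleck", "config": {"in_channels": 2, "latent_dim": 64}},
#           "decoder": {"type": "oobleck", "config": {"out_channels": 2, "latent_dim": 64}},
#           "latent_dim": 64, "downsampling_ratio": 512, "io_channels": 2,
#       },
#   })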
def create_diffAE_from_config(config: Dict[str, Any]):

    diffae_config = config["model"]

    if "encoder" in diffae_config:
        encoder = create_encoder_from_config(diffae_config["encoder"])
    else:
        encoder = None

    if "decoder" in diffae_config:
        decoder = create_decoder_from_config(diffae_config["decoder"])
    else:
        decoder = None

    diffusion_model_type = diffae_config["diffusion"]["type"]

    if diffusion_model_type == "DAU1d":
        diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"])
    elif diffusion_model_type == "adp_1d":
        diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"])
    elif diffusion_model_type == "dit":
        diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
    else:
        raise ValueError(f"Unknown diffusion model type {diffusion_model_type}")

    latent_dim = diffae_config.get("latent_dim", None)
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = diffae_config.get("downsampling_ratio", None)
    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
    io_channels = diffae_config.get("io_channels", None)
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    bottleneck = diffae_config.get("bottleneck", None)

    pretransform = diffae_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)

    if bottleneck is not None:
        bottleneck = create_bottleneck_from_config(bottleneck)

    diffusion_downsampling_ratio = None

    if diffusion_model_type == "DAU1d":
        diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
    elif diffusion_model_type == "adp_1d":
        diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"])
    elif diffusion_model_type == "dit":
        diffusion_downsampling_ratio = 1

    return DiffusionAutoencoder(
        encoder=encoder,
        decoder=decoder,
        diffusion=diffusion,
        io_channels=io_channels,
        sample_rate=sample_rate,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        diffusion_downsampling_ratio=diffusion_downsampling_ratio,
        bottleneck=bottleneck,
        pretransform=pretransform
    )
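# Worked example (illustrative): for a "DAU1d" diffusion config with
# strides [2, 2, 4], diffusion_downsampling_ratio = np.prod([2, 2, 4]) = 16,
# so DiffusionAutoencoder.min_length becomes downsampling_ratio * 16; for "dit"
# the ratio is 1 and min_length reduces to the autoencoder's own downsampling ratio.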