import torch
import math
import numpy as np

from torch import nn, sin, pow
from torch.nn import functional as F
from torch.nn import Parameter
from torchaudio import transforms as T

from alias_free_torch import Activation1d
from dac.nn.layers import WNConv1d, WNConvTranspose1d
from typing import List, Literal, Dict, Any, Callable
from einops import rearrange

from ...inference.sampling import sample
from ...inference.utils import prepare_audio
from .bottleneck import Bottleneck
from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
from .factory import create_pretransform_from_config, create_bottleneck_from_config
from .pretransforms import Pretransform, AutoencoderPretransform


def snake_beta(x, alpha, beta):
    return x + (1.0 / (beta + 0.000000001)) * pow(sin(x * alpha), 2)

try:
    snake_beta = torch.compile(snake_beta)
except RuntimeError:
    pass

# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
# License available in LICENSES/LICENSE_NVIDIA.txt
class SnakeBeta(nn.Module):

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = snake_beta(x, alpha, beta)

        return x
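

# Illustrative sketch (not part of the original module): SnakeBeta applies the
# snake activation x + (1 / beta) * sin^2(alpha * x) with one trainable alpha and
# beta per channel. Shapes below are hypothetical.
def _example_snake_beta():
    act = SnakeBeta(in_features=16)     # one alpha/beta pair per channel
    x = torch.randn(4, 16, 1024)        # [batch, channels, time]
    y = act(x)
    assert y.shape == x.shape           # the activation is shape-preserving
    return y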


def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
    if activation == "elu":
        act = nn.ELU()
    elif activation == "snake":
        act = SnakeBeta(channels)
    elif activation == "none":
        act = nn.Identity()
    else:
        raise ValueError(f"Unknown activation {activation}")

    if antialias:
        act = Activation1d(act)

    return act
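

# Illustrative sketch: requesting an antialiased activation wraps the module in
# alias_free_torch's Activation1d (upsample, activate, downsample). The channel
# count below is hypothetical.
def _example_get_activation():
    act = get_activation("snake", antialias=True, channels=64)
    return act(torch.randn(2, 64, 256))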


class ResidualUnit(nn.Module):
    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        super().__init__()

        self.dilation = dilation

        # "same" padding for the dilated kernel-7 convolution
        padding = (dilation * (7 - 1)) // 2

        self.layers = nn.Sequential(
            # Separate activation instances so trainable snake parameters are not shared
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=7, dilation=dilation, padding=padding),
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=out_channels, out_channels=out_channels,
                     kernel_size=1)
        )

    def forward(self, x):
        return x + self.layers(x)
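

# Illustrative sketch: the padding (dilation * (7 - 1)) // 2 keeps the dilated
# kernel-7 convolution length-preserving, so the residual addition lines up.
# Sizes and dilation below are hypothetical.
def _example_residual_unit():
    unit = ResidualUnit(in_channels=64, out_channels=64, dilation=9)
    x = torch.randn(1, 64, 2048)
    assert unit(x).shape == x.shape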


class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        super().__init__()

        act = get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels)

        self.layers = nn.Sequential(
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=9, use_snake=use_snake),
            act,
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)),
        )

    def forward(self, x):
        return self.layers(x)
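

# Illustrative sketch: the final strided convolution (kernel_size=2*stride,
# padding=ceil(stride/2)) downsamples the sequence by exactly `stride`. Channel
# counts below are hypothetical.
def _example_encoder_block():
    block = EncoderBlock(in_channels=64, out_channels=128, stride=4)
    x = torch.randn(1, 64, 4096)
    assert block(x).shape == (1, 128, 1024)   # 4096 / 4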


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()

        if use_nearest_upsample:
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=2 * stride,
                         stride=1,
                         bias=False,
                         padding='same')
            )
        else:
            upsample_layer = WNConvTranspose1d(in_channels=in_channels,
                                               out_channels=out_channels,
                                               kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))

        act = get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels)

        self.layers = nn.Sequential(
            act,
            upsample_layer,
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=9, use_snake=use_snake),
        )

    def forward(self, x):
        return self.layers(x)
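

# Illustrative sketch: the transposed convolution (kernel_size=2*stride,
# padding=ceil(stride/2)) upsamples by exactly `stride`, mirroring EncoderBlock.
# Sizes below are hypothetical.
def _example_decoder_block():
    block = DecoderBlock(in_channels=128, out_channels=64, stride=4)
    x = torch.randn(1, 128, 1024)
    assert block(x).shape == (1, 64, 4096)    # 1024 * 4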


class OobleckEncoder(nn.Module):
    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults=[1, 2, 4, 8],
                 strides=[2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False
                 ):
        super().__init__()

        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
        ]

        for i in range(self.depth - 1):
            layers += [EncoderBlock(in_channels=c_mults[i] * channels, out_channels=c_mults[i + 1] * channels, stride=strides[i], use_snake=use_snake)]

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
            WNConv1d(in_channels=c_mults[-1] * channels, out_channels=latent_dim, kernel_size=3, padding=1)
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
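

# Illustrative sketch: the encoder's overall downsampling ratio is the product of
# its strides (2 * 4 * 8 * 8 = 512 with the defaults), so stereo audio of length
# 4096 becomes a latent_dim=32 sequence of length 8. Values are hypothetical.
def _example_oobleck_encoder():
    enc = OobleckEncoder()                  # defaults: latent_dim=32, strides=[2, 4, 8, 8]
    x = torch.randn(1, 2, 4096)
    assert enc(x).shape == (1, 32, 8)       # 4096 / 512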


class OobleckDecoder(nn.Module):
    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults=[1, 2, 4, 8],
                 strides=[2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()

        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1] * channels, kernel_size=7, padding=3),
        ]

        for i in range(self.depth - 1, 0, -1):
            layers += [DecoderBlock(
                in_channels=c_mults[i] * channels,
                out_channels=c_mults[i - 1] * channels,
                stride=strides[i - 1],
                use_snake=use_snake,
                antialias_activation=antialias_activation,
                use_nearest_upsample=use_nearest_upsample
            )]

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
            WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
            nn.Tanh() if final_tanh else nn.Identity()
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
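

# Illustrative sketch: with matching defaults the decoder mirrors the encoder, so
# a latent sequence of length T comes back as audio of length T * 512, squashed
# to [-1, 1] by the final tanh. Values are hypothetical.
def _example_oobleck_decoder():
    dec = OobleckDecoder()                  # defaults mirror OobleckEncoder
    z = torch.randn(1, 32, 8)
    assert dec(z).shape == (1, 2, 4096)     # 8 * 512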


class DACEncoderWrapper(nn.Module):
    def __init__(self, in_channels=1, **kwargs):
        super().__init__()

        from dac.model.dac import Encoder as DACEncoder

        latent_dim = kwargs.pop("latent_dim", None)

        encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
        self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
        self.latent_dim = latent_dim

        # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
        self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()

        if in_channels != 1:
            self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)

    def forward(self, x):
        x = self.encoder(x)
        x = self.proj_out(x)
        return x


class DACDecoderWrapper(nn.Module):
    def __init__(self, latent_dim, out_channels=1, **kwargs):
        super().__init__()

        from dac.model.dac import Decoder as DACDecoder

        self.decoder = DACDecoder(**kwargs, input_channel=latent_dim, d_out=out_channels)

        self.latent_dim = latent_dim

    def forward(self, x):
        return self.decoder(x)


class AudioAutoencoder(nn.Module):
    def __init__(
        self,
        encoder,
        decoder,
        latent_dim,
        downsampling_ratio,
        sample_rate,
        io_channels=2,
        bottleneck: Bottleneck = None,
        pretransform: Pretransform = None,
        in_channels=None,
        out_channels=None,
        soft_clip=False
    ):
        super().__init__()

        self.downsampling_ratio = downsampling_ratio
        self.sample_rate = sample_rate

        self.latent_dim = latent_dim
        self.io_channels = io_channels
        self.in_channels = io_channels
        self.out_channels = io_channels

        self.min_length = self.downsampling_ratio

        if in_channels is not None:
            self.in_channels = in_channels

        if out_channels is not None:
            self.out_channels = out_channels

        self.bottleneck = bottleneck
        self.encoder = encoder
        self.decoder = decoder
        self.pretransform = pretransform
        self.soft_clip = soft_clip

    def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):

        info = {}

        if self.pretransform is not None and not skip_pretransform:
            if self.pretransform.enable_grad:
                if iterate_batch:
                    audios = []
                    for i in range(audio.shape[0]):
                        audios.append(self.pretransform.encode(audio[i:i+1]))
                    audio = torch.cat(audios, dim=0)
                else:
                    audio = self.pretransform.encode(audio)
            else:
                with torch.no_grad():
                    if iterate_batch:
                        audios = []
                        for i in range(audio.shape[0]):
                            audios.append(self.pretransform.encode(audio[i:i+1]))
                        audio = torch.cat(audios, dim=0)
                    else:
                        audio = self.pretransform.encode(audio)

        if self.encoder is not None:
            if iterate_batch:
                latents = []
                for i in range(audio.shape[0]):
                    latents.append(self.encoder(audio[i:i+1]))
                latents = torch.cat(latents, dim=0)
            else:
                latents = self.encoder(audio)
        else:
            latents = audio

        if self.bottleneck is not None:
            # TODO: Add iterate batch logic, needs to merge the info dicts
            latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)

            info.update(bottleneck_info)

        if return_info:
            return latents, info

        return latents

    def decode(self, latents, iterate_batch=False, **kwargs):

        if self.bottleneck is not None:
            if iterate_batch:
                decoded = []
                for i in range(latents.shape[0]):
                    decoded.append(self.bottleneck.decode(latents[i:i+1]))
                # Assign back so the decoder below sees the bottleneck-decoded latents
                latents = torch.cat(decoded, dim=0)
            else:
                latents = self.bottleneck.decode(latents)

        if iterate_batch:
            decoded = []
            for i in range(latents.shape[0]):
                decoded.append(self.decoder(latents[i:i+1]))
            decoded = torch.cat(decoded, dim=0)
        else:
            decoded = self.decoder(latents, **kwargs)

        if self.pretransform is not None:
            if self.pretransform.enable_grad:
                if iterate_batch:
                    decodeds = []
                    for i in range(decoded.shape[0]):
                        decodeds.append(self.pretransform.decode(decoded[i:i+1]))
                    decoded = torch.cat(decodeds, dim=0)
                else:
                    decoded = self.pretransform.decode(decoded)
            else:
                with torch.no_grad():
                    if iterate_batch:
                        decodeds = []
                        for i in range(decoded.shape[0]):
                            decodeds.append(self.pretransform.decode(decoded[i:i+1]))
                        decoded = torch.cat(decodeds, dim=0)
                    else:
                        decoded = self.pretransform.decode(decoded)

        if self.soft_clip:
            decoded = torch.tanh(decoded)

        return decoded

    def encode_audio(self, audio, in_sr, **kwargs):
        '''
        Encode single audio tensor to latents, including preprocessing the audio to be compatible with the model
        '''

        if in_sr != self.sample_rate:
            resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
            audio = resample_tf(audio)

        audio_length = audio.shape[-1]

        pad_length = (self.min_length - (audio_length % self.min_length)) % self.min_length

        # Pad with zeros to multiple of model's downsampling ratio
        audio = F.pad(audio, (0, pad_length))

        audio = prepare_audio(audio, in_sr=self.sample_rate, target_sr=self.sample_rate, target_length=audio.shape[1], target_channels=self.in_channels, device=audio.device)

        # TODO: Add chunking logic

        return self.encode(audio, **kwargs)

    def decode_audio(self, latents, **kwargs):
        '''
        Decode latents to audio
        '''

        # TODO: Add chunking logic

        return self.decode(latents, **kwargs)
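

# Illustrative sketch (hypothetical configuration, not from any released model):
# wiring an OobleckEncoder and OobleckDecoder into an AudioAutoencoder with no
# bottleneck or pretransform, then round-tripping audio through encode()/decode().
def _example_audio_autoencoder():
    model = AudioAutoencoder(
        encoder=OobleckEncoder(latent_dim=32),
        decoder=OobleckDecoder(latent_dim=32),
        latent_dim=32,
        downsampling_ratio=512,              # product of the default strides [2, 4, 8, 8]
        sample_rate=44100,
        io_channels=2,
    )
    audio = torch.randn(1, 2, 4096)
    latents = model.encode(audio)            # [1, 32, 8]
    reconstruction = model.decode(latents)   # [1, 2, 4096]
    return reconstruction


# Illustrative sketch of encode_audio's padding arithmetic: with min_length=512,
# a 44100-sample clip needs (512 - 44100 % 512) % 512 = 444 extra samples to
# reach 44544, the next multiple of 512; exact multiples get zero padding.
def _example_pad_length(audio_length=44100, min_length=512):
    pad_length = (min_length - (audio_length % min_length)) % min_length
    assert (audio_length + pad_length) % min_length == 0
    return pad_length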


class DiffusionAutoencoder(AudioAutoencoder):
    def __init__(
        self,
        diffusion: ConditionedDiffusionModel,
        diffusion_downsampling_ratio,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)

        self.diffusion = diffusion

        self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio

        if self.encoder is not None:
            # Shrink the initial encoder parameters to avoid saturated latents
            with torch.no_grad():
                for param in self.encoder.parameters():
                    param *= 0.5

    def decode(self, latents, steps=100):

        upsampled_length = latents.shape[2] * self.downsampling_ratio

        if self.bottleneck is not None:
            latents = self.bottleneck.decode(latents)

        if self.decoder is not None:
            latents = self.decoder(latents)

        # Upsample latents to match diffusion length
        if latents.shape[2] != upsampled_length:
            latents = F.interpolate(latents, size=upsampled_length, mode='nearest')

        noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)
        decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)

        if self.pretransform is not None:
            if self.pretransform.enable_grad:
                decoded = self.pretransform.decode(decoded)
            else:
                with torch.no_grad():
                    decoded = self.pretransform.decode(decoded)

        return decoded


# AE factories

def create_encoder_from_config(encoder_config: Dict[str, Any]):
    encoder_type = encoder_config.get("type", None)
    assert encoder_type is not None, "Encoder type must be specified"

    if encoder_type == "oobleck":
        encoder = OobleckEncoder(
            **encoder_config["config"]
        )
    elif encoder_type == "seanet":
        from encodec.modules import SEANetEncoder
        seanet_encoder_config = encoder_config["config"]

        # SEANet encoder expects strides in reverse order
        seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
        encoder = SEANetEncoder(
            **seanet_encoder_config
        )
    elif encoder_type == "dac":
        dac_config = encoder_config["config"]

        encoder = DACEncoderWrapper(**dac_config)
    elif encoder_type == "local_attn":
        from .local_attention import TransformerEncoder1D

        local_attn_config = encoder_config["config"]

        encoder = TransformerEncoder1D(
            **local_attn_config
        )
    else:
        raise ValueError(f"Unknown encoder type {encoder_type}")

    requires_grad = encoder_config.get("requires_grad", True)
    if not requires_grad:
        for param in encoder.parameters():
            param.requires_grad = False

    return encoder


def create_decoder_from_config(decoder_config: Dict[str, Any]):
    decoder_type = decoder_config.get("type", None)
    assert decoder_type is not None, "Decoder type must be specified"

    if decoder_type == "oobleck":
        decoder = OobleckDecoder(
            **decoder_config["config"]
        )
    elif decoder_type == "seanet":
        from encodec.modules import SEANetDecoder

        decoder = SEANetDecoder(
            **decoder_config["config"]
        )
    elif decoder_type == "dac":
        dac_config = decoder_config["config"]

        decoder = DACDecoderWrapper(**dac_config)
    elif decoder_type == "local_attn":
        from .local_attention import TransformerDecoder1D

        local_attn_config = decoder_config["config"]

        decoder = TransformerDecoder1D(
            **local_attn_config
        )
    else:
        raise ValueError(f"Unknown decoder type {decoder_type}")

    requires_grad = decoder_config.get("requires_grad", True)
    if not requires_grad:
        for param in decoder.parameters():
            param.requires_grad = False

    return decoder


def create_autoencoder_from_config(config: Dict[str, Any]):

    ae_config = config["model"]

    encoder = create_encoder_from_config(ae_config["encoder"])
    decoder = create_decoder_from_config(ae_config["decoder"])

    bottleneck = ae_config.get("bottleneck", None)

    latent_dim = ae_config.get("latent_dim", None)
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = ae_config.get("downsampling_ratio", None)
    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
    io_channels = ae_config.get("io_channels", None)
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    in_channels = ae_config.get("in_channels", None)
    out_channels = ae_config.get("out_channels", None)

    pretransform = ae_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)

    if bottleneck is not None:
        bottleneck = create_bottleneck_from_config(bottleneck)

    soft_clip = ae_config["decoder"].get("soft_clip", False)

    return AudioAutoencoder(
        encoder,
        decoder,
        io_channels=io_channels,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        sample_rate=sample_rate,
        bottleneck=bottleneck,
        pretransform=pretransform,
        in_channels=in_channels,
        out_channels=out_channels,
        soft_clip=soft_clip
    )
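

# Illustrative sketch (hypothetical values): the shape of config dict that
# create_autoencoder_from_config expects. Keys mirror the lookups above; the
# numbers are made up and not taken from any released checkpoint.
def _example_autoencoder_config():
    config = {
        "sample_rate": 44100,
        "model": {
            "encoder": {"type": "oobleck", "config": {"in_channels": 2, "latent_dim": 32}},
            "decoder": {"type": "oobleck", "config": {"out_channels": 2, "latent_dim": 32}},
            "latent_dim": 32,
            "downsampling_ratio": 512,
            "io_channels": 2,
        },
    }
    return create_autoencoder_from_config(config)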


def create_diffAE_from_config(config: Dict[str, Any]):

    diffae_config = config["model"]

    if "encoder" in diffae_config:
        encoder = create_encoder_from_config(diffae_config["encoder"])
    else:
        encoder = None

    if "decoder" in diffae_config:
        decoder = create_decoder_from_config(diffae_config["decoder"])
    else:
        decoder = None

    diffusion_model_type = diffae_config["diffusion"]["type"]

    if diffusion_model_type == "DAU1d":
        diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"])
    elif diffusion_model_type == "adp_1d":
        diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"])
    elif diffusion_model_type == "dit":
        diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
    else:
        raise ValueError(f"Unknown diffusion model type {diffusion_model_type}")

    latent_dim = diffae_config.get("latent_dim", None)
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = diffae_config.get("downsampling_ratio", None)
    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
    io_channels = diffae_config.get("io_channels", None)
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    bottleneck = diffae_config.get("bottleneck", None)

    pretransform = diffae_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)

    if bottleneck is not None:
        bottleneck = create_bottleneck_from_config(bottleneck)

    diffusion_downsampling_ratio = None

    if diffusion_model_type == "DAU1d":
        diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
    elif diffusion_model_type == "adp_1d":
        diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"])
    elif diffusion_model_type == "dit":
        diffusion_downsampling_ratio = 1

    return DiffusionAutoencoder(
        encoder=encoder,
        decoder=decoder,
        diffusion=diffusion,
        io_channels=io_channels,
        sample_rate=sample_rate,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        diffusion_downsampling_ratio=diffusion_downsampling_ratio,
        bottleneck=bottleneck,
        pretransform=pretransform
    )
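

# Illustrative sketch of the ratio arithmetic above: a DiffusionAutoencoder's
# minimum input length is the encoder's downsampling ratio multiplied by the
# diffusion model's own downsampling ratio. Values below are hypothetical.
def _example_diffae_min_length(downsampling_ratio=512, diffusion_strides=(2, 2, 2)):
    diffusion_downsampling_ratio = int(np.prod(diffusion_strides))    # 8 for these strides
    return downsampling_ratio * diffusion_downsampling_ratio          # 4096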