mrfakename's picture
Upload 114 files
c8448bc verified
raw
history blame
23.4 kB
import torch
import math
import numpy as np
from torch import nn, sin, pow
from torch.nn import functional as F
from torch.nn import Parameter
from torchaudio import transforms as T
from alias_free_torch import Activation1d
from dac.nn.layers import WNConv1d, WNConvTranspose1d
from typing import List, Literal, Dict, Any, Callable
from einops import rearrange
from ...inference.sampling import sample
from ...inference.utils import prepare_audio
from .bottleneck import Bottleneck
from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
from .factory import create_pretransform_from_config, create_bottleneck_from_config
from .pretransforms import Pretransform, AutoencoderPretransform
def snake_beta(x, alpha, beta):
    """Snake-beta activation: ``x + sin(x * alpha) ** 2 / beta``.

    A small epsilon (1e-9) is added to ``beta`` before dividing so a zero
    ``beta`` does not produce a division by zero. Arguments must broadcast
    against one another.
    """
    reciprocal_beta = 1.0 / (beta + 0.000000001)
    periodic_term = pow(sin(x * alpha), 2)
    return x + reciprocal_beta * periodic_term
def _maybe_compile(fn):
    """Return ``torch.compile(fn)``, or ``fn`` itself if compilation raises RuntimeError."""
    try:
        return torch.compile(fn)
    except RuntimeError:
        # Compilation unavailable in this environment: keep the eager function.
        return fn


snake_beta = _maybe_compile(snake_beta)
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
# License available in LICENSES/LICENSE_NVIDIA.txt
class SnakeBeta(nn.Module):
    """Per-channel snake-beta activation with learnable ``alpha`` and ``beta``.

    Args:
        in_features: number of channels; one ``alpha`` and one ``beta`` per channel.
        alpha: scale applied to the initial parameter values. In log-scale mode the
            parameters start from zeros, so this scale has no effect there.
        alpha_trainable: whether both ``alpha`` and ``beta`` receive gradients.
        alpha_logscale: if True, parameters are stored in log space and
            exponentiated in ``forward``.
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        # Log-space parameters start at 0 (exp(0) == 1); linear-space ones start at 1.
        initializer = torch.zeros if alpha_logscale else torch.ones
        self.alpha = Parameter(initializer(in_features) * alpha)
        self.beta = Parameter(initializer(in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """Apply the activation; ``x`` is laid out as [B, C, T] with C == in_features."""
        # Reshape the per-channel vectors to [1, C, 1] so they broadcast over batch and time.
        alpha = self.alpha[None, :, None]
        beta = self.beta[None, :, None]
        if self.alpha_logscale:
            alpha, beta = alpha.exp(), beta.exp()
        return snake_beta(x, alpha, beta)
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
    """Build an activation module by name.

    Args:
        activation: ``"elu"``, ``"snake"`` (SnakeBeta over ``channels``) or ``"none"`` (identity).
        antialias: if truthy, wrap the activation in ``Activation1d``.
        channels: channel count, only used by ``"snake"``.

    Raises:
        ValueError: for an unrecognized activation name.
    """
    builders = {
        "elu": nn.ELU,
        "snake": lambda: SnakeBeta(channels),
        "none": nn.Identity,
    }
    if activation not in builders:
        raise ValueError(f"Unknown activation {activation}")
    module = builders[activation]()
    return Activation1d(module) if antialias else module
class ResidualUnit(nn.Module):
    """Residual block: act -> dilated conv (k=7) -> act -> pointwise conv, added back to the input.

    ``forward`` returns ``x + layers(x)``, so the output of ``layers`` (``out_channels``)
    must be addable to ``x`` (``in_channels``); in this file the two are always equal.

    Args:
        in_channels: channels of the input.
        out_channels: channels produced by the convolutions.
        dilation: dilation of the kernel-7 convolution.
        use_snake: use SnakeBeta activations instead of ELU.
        antialias_activation: wrap each activation in ``Activation1d``.
    """

    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        super().__init__()
        self.dilation = dilation
        activation = "snake" if use_snake else "elu"
        # Padding 3 * dilation keeps the time length unchanged for a kernel of size 7.
        padding = (dilation * (7 - 1)) // 2
        # Build a separate activation for each position. Reusing one instance would tie the
        # trainable SnakeBeta parameters of both positions together (and make checkpoint
        # loading overwrite one with the other). State-dict keys stay layers.0 / layers.2.
        self.layers = nn.Sequential(
            get_activation(activation, antialias=antialias_activation, channels=in_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=7, dilation=dilation, padding=padding),
            get_activation(activation, antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=out_channels, out_channels=out_channels,
                     kernel_size=1),
        )

    def forward(self, x):
        return x + self.layers(x)
class EncoderBlock(nn.Module):
    """Downsampling stage: three residual units (dilations 1, 3, 9), an activation,
    then a strided conv (kernel ``2 * stride``) mapping ``in_channels`` to ``out_channels``.

    NOTE(review): ``antialias_activation`` only affects the final activation; it is not
    forwarded to the residual units.
    """

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        super().__init__()
        final_act = get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels)
        residual_units = [
            ResidualUnit(in_channels=in_channels, out_channels=in_channels, dilation=d, use_snake=use_snake)
            for d in (1, 3, 9)
        ]
        downsample = WNConv1d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))
        self.layers = nn.Sequential(*residual_units, final_act, downsample)

    def forward(self, x):
        return self.layers(x)
class DecoderBlock(nn.Module):
    """Upsampling stage: activation, upsampling by ``stride``, then three residual units (dilations 1, 3, 9).

    Upsampling is either a weight-normed transposed conv, or (``use_nearest_upsample``)
    nearest-neighbour interpolation followed by a ``padding='same'`` conv.

    NOTE(review): ``antialias_activation`` only affects the leading activation; it is not
    forwarded to the residual units.
    """

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()
        kernel_size = 2 * stride
        if not use_nearest_upsample:
            upsample = WNConvTranspose1d(in_channels=in_channels,
                                         out_channels=out_channels,
                                         kernel_size=kernel_size, stride=stride, padding=math.ceil(stride / 2))
        else:
            upsample = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=kernel_size,
                         stride=1,
                         bias=False,
                         padding='same'),
            )
        leading_act = get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels)
        residual_units = [
            ResidualUnit(in_channels=out_channels, out_channels=out_channels, dilation=d, use_snake=use_snake)
            for d in (1, 3, 9)
        ]
        self.layers = nn.Sequential(leading_act, upsample, *residual_units)

    def forward(self, x):
        return self.layers(x)
class OobleckEncoder(nn.Module):
    """Convolutional waveform encoder: input conv, one EncoderBlock per stage, activation, output conv.

    Args:
        in_channels: channels of the input audio.
        channels: base width; stage ``i`` has ``channels * c_mults[i]`` channels.
        latent_dim: channels of the output.
        c_mults: per-stage channel multipliers (any sequence; a leading 1 is prepended).
        strides: per-stage strides; needs at least ``len(c_mults)`` entries.
        use_snake: use SnakeBeta activations instead of ELU.
        antialias_activation: wrap the final activation in ``Activation1d``.
            NOTE(review): not forwarded to the EncoderBlocks.
    """

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults=(1, 2, 4, 8),
                 strides=(2, 4, 8, 8),
                 use_snake=False,
                 antialias_activation=False
                 ):
        super().__init__()
        # Build a fresh list: accepts tuples/lists alike and never aliases the caller's
        # sequence (the defaults are immutable tuples for the same reason).
        c_mults = [1] + list(c_mults)
        self.depth = len(c_mults)
        layers = [
            WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
        ]
        for i in range(self.depth - 1):
            layers.append(EncoderBlock(in_channels=c_mults[i] * channels,
                                       out_channels=c_mults[i + 1] * channels,
                                       stride=strides[i],
                                       use_snake=use_snake))
        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
            WNConv1d(in_channels=c_mults[-1] * channels, out_channels=latent_dim, kernel_size=3, padding=1)
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
class OobleckDecoder(nn.Module):
    """Convolutional waveform decoder: input conv, one DecoderBlock per stage (deepest first),
    activation, bias-free output conv, then optional tanh.

    Args:
        out_channels: channels of the output audio.
        channels: base width; stage ``i`` has ``channels * c_mults[i]`` channels.
        latent_dim: channels of the input latents.
        c_mults: per-stage channel multipliers (any sequence; a leading 1 is prepended).
        strides: per-stage strides; needs at least ``len(c_mults)`` entries.
        use_snake: use SnakeBeta activations instead of ELU.
        antialias_activation: forwarded to each DecoderBlock and the final activation.
        use_nearest_upsample: forwarded to each DecoderBlock.
        final_tanh: append ``nn.Tanh`` (otherwise ``nn.Identity``).
    """

    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults=(1, 2, 4, 8),
                 strides=(2, 4, 8, 8),
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()
        # Fresh list: accepts any sequence and never aliases the caller's data.
        c_mults = [1] + list(c_mults)
        self.depth = len(c_mults)
        layers = [
            WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1] * channels, kernel_size=7, padding=3),
        ]
        # Walk the stages from the widest back to the narrowest.
        for i in range(self.depth - 1, 0, -1):
            layers.append(DecoderBlock(
                in_channels=c_mults[i] * channels,
                out_channels=c_mults[i - 1] * channels,
                stride=strides[i - 1],
                use_snake=use_snake,
                antialias_activation=antialias_activation,
                use_nearest_upsample=use_nearest_upsample
            ))
        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
            WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
            nn.Tanh() if final_tanh else nn.Identity()
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
class DACEncoderWrapper(nn.Module):
    """Adapter around ``dac.model.dac.Encoder``.

    Args:
        in_channels: input audio channels; if not 1, the encoder's first conv is replaced.
        **kwargs: forwarded to the DAC encoder. ``latent_dim`` (optional) is removed and,
            when given, adds a 1x1 ``nn.Conv1d`` projection from ``encoder.enc_dim`` to it.
            ``d_model`` defaults to 64 and ``strides`` to ``[2, 4, 8, 8]``.
    """

    def __init__(self, in_channels=1, **kwargs):
        super().__init__()
        from dac.model.dac import Encoder as DACEncoder

        latent_dim = kwargs.pop("latent_dim", None)
        # Make the defaults explicit so the width computed here always matches what the
        # DAC encoder builds, and omitting either key no longer raises KeyError.
        kwargs.setdefault("d_model", 64)
        kwargs.setdefault("strides", [2, 4, 8, 8])
        # Channel width doubles once per stride stage.
        encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
        self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
        self.latent_dim = latent_dim
        # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
        self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()
        if in_channels != 1:
            self.encoder.block[0] = WNConv1d(in_channels, kwargs["d_model"], kernel_size=7, padding=3)

    def forward(self, x):
        x = self.encoder(x)
        x = self.proj_out(x)
        return x
class DACDecoderWrapper(nn.Module):
    """Adapter around ``dac.model.dac.Decoder`` reading ``latent_dim`` input channels.

    Args:
        latent_dim: channel count of the incoming latents (DAC's ``input_channel``).
        out_channels: audio channels produced (DAC's ``d_out``).
        **kwargs: remaining DAC decoder arguments.
    """

    def __init__(self, latent_dim, out_channels=1, **kwargs):
        super().__init__()
        from dac.model.dac import Decoder as DACDecoder

        self.latent_dim = latent_dim
        self.decoder = DACDecoder(input_channel=latent_dim, d_out=out_channels, **kwargs)

    def forward(self, x):
        decoded = self.decoder(x)
        return decoded
class AudioAutoencoder(nn.Module):
    """Waveform autoencoder assembled from pluggable parts.

    Encoding: optional ``pretransform.encode`` -> optional ``encoder`` -> optional
    ``bottleneck.encode``. Decoding: optional ``bottleneck.decode`` -> ``decoder`` ->
    optional ``pretransform.decode`` -> optional ``tanh`` soft clip.

    Args:
        encoder: module applied to (pretransformed) audio; ``None`` passes audio through as latents.
        decoder: module applied to latents in ``decode``.
        latent_dim: number of latent channels (stored on the instance).
        downsampling_ratio: stored; also the length multiple ``encode_audio`` pads to.
        sample_rate: rate the model expects; ``encode_audio`` resamples to it.
        io_channels: default for both ``in_channels`` and ``out_channels``.
        bottleneck: optional bottleneck applied after encoding / before decoding.
        pretransform: optional pretransform wrapped around encoder and decoder.
        in_channels: overrides ``io_channels`` for the input channel count.
        out_channels: overrides ``io_channels`` for the output channel count.
        soft_clip: if True, ``decode`` applies ``torch.tanh`` to its output.
    """

    def __init__(
        self,
        encoder,
        decoder,
        latent_dim,
        downsampling_ratio,
        sample_rate,
        io_channels=2,
        bottleneck: Bottleneck = None,
        pretransform: Pretransform = None,
        in_channels = None,
        out_channels = None,
        soft_clip = False
    ):
        super().__init__()
        self.downsampling_ratio = downsampling_ratio
        self.sample_rate = sample_rate
        self.latent_dim = latent_dim
        self.io_channels = io_channels
        self.in_channels = io_channels if in_channels is None else in_channels
        self.out_channels = io_channels if out_channels is None else out_channels
        self.min_length = self.downsampling_ratio
        self.bottleneck = bottleneck
        self.encoder = encoder
        self.decoder = decoder
        self.pretransform = pretransform
        self.soft_clip = soft_clip

    @staticmethod
    def _map_batch(fn, x, iterate_batch):
        """Apply ``fn`` to all of ``x``, or to each size-1 slice along dim 0 and concatenate the results."""
        if not iterate_batch:
            return fn(x)
        return torch.cat([fn(x[i:i + 1]) for i in range(x.shape[0])], dim=0)

    def _pretransform_grad_context(self):
        """Autograd context for pretransform calls.

        Leaves the ambient autograd mode unchanged when ``self.pretransform.enable_grad``
        is truthy, otherwise returns ``torch.no_grad()``.
        """
        import contextlib

        return contextlib.nullcontext() if self.pretransform.enable_grad else torch.no_grad()

    def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
        """Encode audio to latents.

        Args:
            audio: batched audio tensor (dim 0 is iterated when ``iterate_batch``).
            return_info: also return the info dict filled by the bottleneck.
            skip_pretransform: bypass the pretransform even if one is configured.
            iterate_batch: run pretransform and encoder one batch item at a time.
            **kwargs: forwarded to ``bottleneck.encode``.

        Returns:
            ``latents``, or ``(latents, info)`` when ``return_info`` is True.
        """
        info = {}
        if self.pretransform is not None and not skip_pretransform:
            with self._pretransform_grad_context():
                audio = self._map_batch(self.pretransform.encode, audio, iterate_batch)
        if self.encoder is not None:
            latents = self._map_batch(self.encoder, audio, iterate_batch)
        else:
            latents = audio
        if self.bottleneck is not None:
            # TODO: Add iterate batch logic, needs to merge the info dicts
            latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
            info.update(bottleneck_info)
        if return_info:
            return latents, info
        return latents

    def decode(self, latents, iterate_batch=False, **kwargs):
        """Decode latents back to audio.

        Args:
            latents: batched latent tensor.
            iterate_batch: run bottleneck, decoder and pretransform one batch item at a time.
            **kwargs: forwarded to ``self.decoder`` (in both batched and iterated mode).
        """
        if self.bottleneck is not None:
            # The bottleneck output (per item or batched) is what the decoder consumes.
            latents = self._map_batch(self.bottleneck.decode, latents, iterate_batch)
        decoded = self._map_batch(lambda z: self.decoder(z, **kwargs), latents, iterate_batch)
        if self.pretransform is not None:
            with self._pretransform_grad_context():
                decoded = self._map_batch(self.pretransform.decode, decoded, iterate_batch)
        if self.soft_clip:
            decoded = torch.tanh(decoded)
        return decoded

    def encode_audio(self, audio, in_sr, **kwargs):
        '''
        Encode single audio tensor to latents, including preprocessing the audio to be compatible with the model
        '''
        if in_sr != self.sample_rate:
            resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
            audio = resample_tf(audio)
        audio_length = audio.shape[-1]
        pad_length = (self.min_length - (audio_length % self.min_length)) % self.min_length
        # Pad with zeros to multiple of model's downsampling ratio
        audio = F.pad(audio, (0, pad_length))
        # Use the last (time) axis -- the one padded above -- as the target length;
        # shape[1] would be the channel axis if a batch dimension is present.
        audio = prepare_audio(audio, in_sr=self.sample_rate, target_sr=self.sample_rate, target_length=audio.shape[-1], target_channels=self.in_channels, device=audio.device)
        # TODO: Add chunking logic
        return self.encode(audio, **kwargs)

    def decode_audio(self, latents, **kwargs):
        '''
        Decode latents to audio
        '''
        # TODO: Add chunking logic
        return self.decode(latents, **kwargs)
class DiffusionAutoencoder(AudioAutoencoder):
    """AudioAutoencoder whose decoding is done by sampling a conditioned diffusion model.

    Args:
        diffusion: diffusion model used by ``decode`` via ``sample``.
        diffusion_downsampling_ratio: internal downsampling of the diffusion model;
            multiplied into ``min_length`` so ``encode_audio`` pads to a compatible length.
        *args, **kwargs: forwarded to ``AudioAutoencoder``.
    """

    def __init__(
        self,
        diffusion: ConditionedDiffusionModel,
        diffusion_downsampling_ratio,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.diffusion = diffusion
        self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio
        if self.encoder is not None:
            # Shrink the initial encoder parameters to avoid saturated latents
            with torch.no_grad():
                for param in self.encoder.parameters():
                    param *= 0.5

    def decode(self, latents, steps=100):
        """Decode latents by diffusion sampling conditioned on the (upsampled) latents.

        Args:
            latents: latent tensor indexed as [batch, channels, time] (dim 2 is scaled
                by ``downsampling_ratio`` to get the output length).
            steps: number of sampling steps passed to ``sample``.
        """
        upsampled_length = latents.shape[2] * self.downsampling_ratio
        if self.bottleneck is not None:
            latents = self.bottleneck.decode(latents)
        if self.decoder is not None:
            # Run the latent decoder module (calling self.decode here would recurse forever).
            latents = self.decoder(latents)
        # Upsample latents to match diffusion length
        if latents.shape[2] != upsampled_length:
            latents = F.interpolate(latents, size=upsampled_length, mode='nearest')
        noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)
        decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)
        if self.pretransform is not None:
            if self.pretransform.enable_grad:
                decoded = self.pretransform.decode(decoded)
            else:
                with torch.no_grad():
                    decoded = self.pretransform.decode(decoded)
        return decoded
# AE factories
def create_encoder_from_config(encoder_config: Dict[str, Any]):
    """Build an encoder module from a config dict.

    Args:
        encoder_config: dict with ``"type"`` (``"oobleck"``, ``"seanet"``, ``"dac"`` or
            ``"local_attn"``), ``"config"`` (constructor kwargs) and optional
            ``"requires_grad"`` (default True; False freezes every parameter).

    Returns:
        The constructed encoder module.

    Raises:
        AssertionError: if ``"type"`` is missing.
        ValueError: if ``"type"`` is not recognized.
    """
    encoder_type = encoder_config.get("type", None)
    assert encoder_type is not None, "Encoder type must be specified"
    if encoder_type == "oobleck":
        encoder = OobleckEncoder(
            **encoder_config["config"]
        )
    elif encoder_type == "seanet":
        from encodec.modules import SEANetEncoder
        # Work on a copy: editing the caller's dict in place meant that building twice
        # from the same config reversed the ratios back into the wrong order.
        seanet_encoder_config = dict(encoder_config["config"])
        # SEANet encoder expects strides in reverse order
        seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
        encoder = SEANetEncoder(
            **seanet_encoder_config
        )
    elif encoder_type == "dac":
        dac_config = encoder_config["config"]
        encoder = DACEncoderWrapper(**dac_config)
    elif encoder_type == "local_attn":
        from .local_attention import TransformerEncoder1D
        local_attn_config = encoder_config["config"]
        encoder = TransformerEncoder1D(
            **local_attn_config
        )
    else:
        raise ValueError(f"Unknown encoder type {encoder_type}")
    requires_grad = encoder_config.get("requires_grad", True)
    if not requires_grad:
        for param in encoder.parameters():
            param.requires_grad = False
    return encoder
def create_decoder_from_config(decoder_config: Dict[str, Any]):
    """Build a decoder module from a config dict.

    Args:
        decoder_config: dict with ``"type"`` (``"oobleck"``, ``"seanet"``, ``"dac"`` or
            ``"local_attn"``), ``"config"`` (constructor kwargs) and optional
            ``"requires_grad"`` (default True; False freezes every parameter).

    Raises:
        AssertionError: if ``"type"`` is missing.
        ValueError: if ``"type"`` is not recognized.
    """
    decoder_type = decoder_config.get("type", None)
    assert decoder_type is not None, "Decoder type must be specified"

    def _build():
        # "config" is only read once the type is known to be valid.
        if decoder_type == "oobleck":
            return OobleckDecoder(**decoder_config["config"])
        if decoder_type == "seanet":
            from encodec.modules import SEANetDecoder
            return SEANetDecoder(**decoder_config["config"])
        if decoder_type == "dac":
            return DACDecoderWrapper(**decoder_config["config"])
        if decoder_type == "local_attn":
            from .local_attention import TransformerDecoder1D
            return TransformerDecoder1D(**decoder_config["config"])
        raise ValueError(f"Unknown decoder type {decoder_type}")

    decoder = _build()
    if not decoder_config.get("requires_grad", True):
        for param in decoder.parameters():
            param.requires_grad = False
    return decoder
def create_autoencoder_from_config(config: Dict[str, Any]):
    """Build an ``AudioAutoencoder`` from a full model config.

    Reads ``config["model"]`` (``encoder``, ``decoder``, ``latent_dim``,
    ``downsampling_ratio``, ``io_channels``, optional ``in_channels``/``out_channels``,
    ``pretransform``, ``bottleneck``, and ``decoder.soft_clip``) plus
    ``config["sample_rate"]``.

    Raises:
        AssertionError: if any of latent_dim, downsampling_ratio, io_channels or
            sample_rate is missing (checked in that order).
    """
    ae_config = config["model"]
    encoder = create_encoder_from_config(ae_config["encoder"])
    decoder = create_decoder_from_config(ae_config["decoder"])

    required = {}
    for key, source in (
        ("latent_dim", ae_config),
        ("downsampling_ratio", ae_config),
        ("io_channels", ae_config),
        ("sample_rate", config),
    ):
        value = source.get(key, None)
        assert value is not None, f"{key} must be specified in model config"
        required[key] = value

    pretransform_config = ae_config.get("pretransform", None)
    pretransform = (
        None if pretransform_config is None
        else create_pretransform_from_config(pretransform_config, required["sample_rate"])
    )
    bottleneck_config = ae_config.get("bottleneck", None)
    bottleneck = None if bottleneck_config is None else create_bottleneck_from_config(bottleneck_config)

    return AudioAutoencoder(
        encoder,
        decoder,
        bottleneck=bottleneck,
        pretransform=pretransform,
        in_channels=ae_config.get("in_channels", None),
        out_channels=ae_config.get("out_channels", None),
        soft_clip=ae_config["decoder"].get("soft_clip", False),
        **required,
    )
def create_diffAE_from_config(config: Dict[str, Any]):
    """Build a ``DiffusionAutoencoder`` from a full model config.

    Reads ``config["model"]`` (optional ``encoder``/``decoder``, required ``diffusion``
    with ``type`` in {"DAU1d", "adp_1d", "dit"} and ``config``, ``latent_dim``,
    ``downsampling_ratio``, ``io_channels``, optional ``bottleneck``/``pretransform``)
    plus ``config["sample_rate"]``.

    Raises:
        ValueError: if the diffusion type is not recognized.
        AssertionError: if latent_dim, downsampling_ratio, io_channels or sample_rate is missing.
    """
    diffae_config = config["model"]
    encoder = create_encoder_from_config(diffae_config["encoder"]) if "encoder" in diffae_config else None
    decoder = create_decoder_from_config(diffae_config["decoder"]) if "decoder" in diffae_config else None

    diffusion_spec = diffae_config["diffusion"]
    diffusion_model_type = diffusion_spec["type"]
    # The downsampling ratio is derived from the same backbone config; it feeds
    # DiffusionAutoencoder.min_length. int() turns numpy scalars into plain ints.
    if diffusion_model_type == "DAU1d":
        diffusion = DAU1DCondWrapper(**diffusion_spec["config"])
        diffusion_downsampling_ratio = int(np.prod(diffusion_spec["config"]["strides"]))
    elif diffusion_model_type == "adp_1d":
        diffusion = UNet1DCondWrapper(**diffusion_spec["config"])
        diffusion_downsampling_ratio = int(np.prod(diffusion_spec["config"]["factors"]))
    elif diffusion_model_type == "dit":
        diffusion = DiTWrapper(**diffusion_spec["config"])
        diffusion_downsampling_ratio = 1
    else:
        # Previously fell through and failed later with an unbound `diffusion`.
        raise ValueError(f"Unknown diffusion model type {diffusion_model_type}")

    latent_dim = diffae_config.get("latent_dim", None)
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = diffae_config.get("downsampling_ratio", None)
    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
    io_channels = diffae_config.get("io_channels", None)
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    bottleneck = diffae_config.get("bottleneck", None)
    pretransform = diffae_config.get("pretransform", None)
    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)
    if bottleneck is not None:
        bottleneck = create_bottleneck_from_config(bottleneck)

    return DiffusionAutoencoder(
        encoder=encoder,
        decoder=decoder,
        diffusion=diffusion,
        io_channels=io_channels,
        sample_rate=sample_rate,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        diffusion_downsampling_ratio=diffusion_downsampling_ratio,
        bottleneck=bottleneck,
        pretransform=pretransform
    )