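"""Pretransforms: paired encode/decode wrappers that map audio to a latent
representation and back. Includes an identity base class plus pretrained
autoencoder, wavelet, PQMF, and pretrained DAC variants."""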
from einops import rearrange
from torch import nn

class Pretransform(nn.Module):
    """Base pretransform: an identity encode/decode pair.

    Subclasses override encode/decode and set io_channels, encoded_channels
    and downsampling_ratio to describe their latent representation.
    """

    def __init__(self, enable_grad=False, io_channels=2):
        super().__init__()

        self.io_channels = io_channels
        self.encoded_channels = None
        self.downsampling_ratio = None

        self.enable_grad = enable_grad

    def encode(self, x):
        return x

    def decode(self, z):
        return z
class AutoencoderPretransform(Pretransform):
    """Wraps a pretrained (frozen) audio autoencoder as a pretransform.

    Latents are divided by `scale` on encode and multiplied back on decode.
    With `model_half`, the wrapped model runs in fp16 while latents are
    returned in fp32.
    """

    def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False):
        super().__init__()
        self.model = model
        self.model.requires_grad_(False).eval()
        self.scale = scale
        self.downsampling_ratio = model.downsampling_ratio
        self.io_channels = model.io_channels
        self.sample_rate = model.sample_rate

        self.model_half = model_half
        self.iterate_batch = iterate_batch

        self.encoded_channels = model.latent_dim

        if self.model_half:
            self.model.half()

    def encode(self, x, **kwargs):
        if self.model_half:
            x = x.half()

        encoded = self.model.encode(x, iterate_batch=self.iterate_batch, **kwargs)

        if self.model_half:
            encoded = encoded.float()

        return encoded / self.scale

    def decode(self, z, **kwargs):
        z = z * self.scale

        if self.model_half:
            z = z.half()

        decoded = self.model.decode(z, iterate_batch=self.iterate_batch, **kwargs)

        if self.model_half:
            decoded = decoded.float()

        return decoded

    def load_state_dict(self, state_dict, strict=True):
        self.model.load_state_dict(state_dict, strict=strict)
class WaveletPretransform(Pretransform):
    def __init__(self, channels, levels, wavelet):
        super().__init__()

        from .wavelets import WaveletEncode1d, WaveletDecode1d

        self.encoder = WaveletEncode1d(channels, levels, wavelet)
        self.decoder = WaveletDecode1d(channels, levels, wavelet)

        self.downsampling_ratio = 2 ** levels
        self.io_channels = channels
        self.encoded_channels = channels * self.downsampling_ratio

    def encode(self, x):
        return self.encoder(x)

    def decode(self, z):
        return self.decoder(z)
class PQMFPretransform(Pretransform):
    def __init__(self, attenuation=100, num_bands=16):
        super().__init__()
        from .pqmf import PQMF
        self.pqmf = PQMF(attenuation, num_bands)

    def encode(self, x):
        # x is (Batch x Channels x Time)
        x = self.pqmf.forward(x)
        # pqmf.forward returns (Batch x Channels x Bands x Time),
        # but Pretransform needs (Batch x Channels x Time),
        # so concatenate channels and bands into one axis
        return rearrange(x, "b c n t -> b (c n) t")

    def decode(self, x):
        # x is (Batch x (Channels Bands) x Time); convert back to (Batch x Channels x Bands x Time)
        x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
        # inverse returns (Batch x Channels x Time)
        return self.pqmf.inverse(x)
class PretrainedDACPretransform(Pretransform):
    """Wraps a pretrained Descript Audio Codec (DAC) model as a pretransform.

    With quantize_on_decode=True, encode() returns continuous encoder latents
    and quantization is applied inside decode(); otherwise latents are
    quantized during encode().
    """

    def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True):
        super().__init__()

        import dac

        model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate)

        self.model = dac.DAC.load(model_path)

        self.quantize_on_decode = quantize_on_decode

        if model_type == "44khz":
            self.downsampling_ratio = 512
        else:
            self.downsampling_ratio = 320

        self.io_channels = 1

        self.scale = scale

        self.chunked = chunked

        self.encoded_channels = self.model.latent_dim

    def encode(self, x):
        latents = self.model.encoder(x)

        if self.quantize_on_decode:
            output = latents
        else:
            z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
            output = z

        if self.scale != 1.0:
            output = output / self.scale

        return output

    def decode(self, z):
        if self.scale != 1.0:
            z = z * self.scale

        if self.quantize_on_decode:
            z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)

        return self.model.decode(z)
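
if __name__ == "__main__":
    # Minimal round-trip sketch, not part of the library API. It assumes this file
    # sits inside its package so the relative .wavelets import resolves, and that
    # "bior4.4" is a wavelet name accepted by WaveletEncode1d/WaveletDecode1d.
    import torch

    pretransform = WaveletPretransform(channels=2, levels=5, wavelet="bior4.4")

    audio = torch.randn(1, 2, 2 ** 16)    # (batch, channels, samples)
    latents = pretransform.encode(audio)  # channels * 2**levels latent channels, time / 2**levels
    recon = pretransform.decode(latents)  # back to (batch, channels, samples)

    print(audio.shape, latents.shape, recon.shape)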