# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import typing as tp

import torch

from .modules import SEANetDecoder
from .modules import SEANetEncoder
from .quantization import ResidualVectorQuantizer


################################################################################
# Encodec neural audio codec
################################################################################


class Encodec(torch.nn.Module):
    """
    Encodec neural audio codec proposed in "High Fidelity Neural Audio
    Compression" (https://arxiv.org/abs/2210.13438) by Défossez et al.
    """

    def __init__(
        self,
        sample_rate: int,
        channels: int,
        causal: bool,
        model_norm: str,
        target_bandwidths: tp.Sequence[float],
        audio_normalize: bool,
        ratios: tp.List[int] = (8, 5, 4, 2),
        codebook_size: int = 1024,
        n_filters: int = 32,
        true_skip: bool = False,
        encoder_kwargs: tp.Optional[tp.Dict] = None,
        decoder_kwargs: tp.Optional[tp.Dict] = None,
    ):
        """
        Parameters
        ----------
        sample_rate : int
            Audio sample rate in Hz.
        channels : int
            Number of audio channels expected at input.
        causal : bool
            Whether to use causal convolution layers in the encoder/decoder.
        model_norm : str
            Type of normalization to use in the encoder/decoder.
        target_bandwidths : tp.Sequence[float]
            List of target bandwidths in kb/s.
        audio_normalize : bool
            Whether to normalize encoded and decoded audio segments using
            simple scaling factors.
        ratios : tp.List[int], optional
            List of downsampling ratios used in the encoder/decoder,
            by default (8, 5, 4, 2)
        codebook_size : int, optional
            Size of the residual vector quantizer codebooks, by default 1024
        n_filters : int, optional
            Number of filters used in the encoder/decoder, by default 32
        true_skip : bool, optional
            Whether to use true skip connections in the encoder/decoder rather
            than convolutional skip connections, by default False
        encoder_kwargs : tp.Dict, optional
            Additional keyword arguments forwarded to the SEANet encoder,
            by default None
        decoder_kwargs : tp.Dict, optional
            Additional keyword arguments forwarded to the SEANet decoder,
            by default None
        """
        super().__init__()

        encoder_kwargs = encoder_kwargs or {}
        decoder_kwargs = decoder_kwargs or {}

        self.encoder = SEANetEncoder(
            channels=channels,
            causal=causal,
            norm=model_norm,
            ratios=ratios,
            n_filters=n_filters,
            true_skip=true_skip,
            **encoder_kwargs,
        )
        self.decoder = SEANetDecoder(
            channels=channels,
            causal=causal,
            norm=model_norm,
            ratios=ratios,
            n_filters=n_filters,
            true_skip=true_skip,
            **decoder_kwargs,
        )

        # Maximum number of quantizers needed to reach the largest target
        # bandwidth at the encoder's frame rate.
        n_q = int(
            1000
            * target_bandwidths[-1]
            // (math.ceil(sample_rate / self.encoder.hop_length) * 10)
        )
        self.n_q = n_q  # Maximum number of quantizers
        self.quantizer = ResidualVectorQuantizer(
            dimension=self.encoder.dimension,
            n_q=n_q,
            bins=codebook_size,
        )

        self.sample_rate = sample_rate
        self.normalize = audio_normalize
        self.channels = channels
        self.frame_rate = math.ceil(self.sample_rate / math.prod(self.encoder.ratios))
        self.target_bandwidths = target_bandwidths

        self.bits_per_codebook = int(math.log2(self.quantizer.bins))
        assert (
            2**self.bits_per_codebook == self.quantizer.bins
        ), "quantizer bins must be a power of 2."

        self.bandwidth = self.target_bandwidths[-1]

    def set_target_bandwidth(self, bandwidth: float):
        """
        Set the target bandwidth for the codec by adjusting the number of
        residual vector quantizers used.
        """
        if bandwidth not in self.target_bandwidths:
            raise ValueError(
                f"This model doesn't support the bandwidth {bandwidth}. "
                f"Select one of {self.target_bandwidths}."
            )
        self.bandwidth = bandwidth

    def encode(self, x: torch.Tensor):
        """
        Map a given audio waveform `x` to discrete residual latent codes.

        Parameters
        ----------
        x : torch.Tensor
            Audio waveform of shape `(n_batch, n_channels, n_samples)`.

        Returns
        -------
        codes : torch.Tensor
            Tensor of shape `(n_batch, n_codebooks, n_frames)`, followed by
            the intermediate quantizer outputs and the unquantized latent `z`.
        """
        assert x.dim() == 3
        _, channels, length = x.shape
        assert 0 < channels <= 2

        z = self.encoder(x)
        codes, z_O, z_o = self.quantizer.encode(z, self.frame_rate, self.bandwidth)
        codes = codes.transpose(0, 1)
        return codes, z_O, z_o, z

    def decode(self, codes: torch.Tensor):
        """
        Decode quantized latents to obtain waveform audio.

        Parameters
        ----------
        codes : torch.Tensor
            Tensor of shape `(n_batch, n_codebooks, n_frames)`.

        Returns
        -------
        out : torch.Tensor
            Tensor of shape `(n_batch, n_channels, n_samples)`.
        """
        codes = codes.transpose(0, 1)
        emb = self.quantizer.decode(codes)
        out = self.decoder(emb)
        return out
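

# ------------------------------------------------------------------------------
# Minimal usage sketch. The constructor arguments below (24 kHz mono audio,
# "weight_norm" normalization, target bandwidths of 1.5/3/6 kb/s) are assumed
# illustrative values, not defaults prescribed by this module. Guarded under
# __main__ so importing the module is unaffected; because of the relative
# imports above, run it as a module from the package root.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    model = Encodec(
        sample_rate=24_000,        # assumed sample rate
        channels=1,                # mono input
        causal=True,               # streaming-friendly convolutions
        model_norm="weight_norm",  # assumed normalization choice
        target_bandwidths=[1.5, 3.0, 6.0],
        audio_normalize=False,
    )
    model.set_target_bandwidth(6.0)

    # One second of random mono audio: (n_batch, n_channels, n_samples).
    x = torch.randn(1, 1, 24_000)
    codes, *quantizer_extras, z = model.encode(x)
    print("codes shape:", codes.shape)  # (n_batch, n_codebooks, n_frames)

    y = model.decode(codes)
    print("reconstruction shape:", y.shape)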