# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Modified by Yiwei Guo, 2024
# including Group VQ

"""Core vector quantization implementation."""

import logging
import typing as tp

from einops import rearrange, repeat
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

import vec2wav2.distributed.distrib as distrib


def default(val: tp.Any, d: tp.Any) -> tp.Any:
    """Return ``val`` unless it is ``None``, in which case return ``d``."""
    return val if val is not None else d


def ema_inplace(moving_avg, new, decay: float):
    """Update ``moving_avg`` in place to ``moving_avg * decay + new * (1 - decay)``."""
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Return ``(x + epsilon) / (x.sum() + n_categories * epsilon)``."""
    return (x + epsilon) / (x.sum() + n_categories * epsilon)


def uniform_init(*shape: int):
    """Return a new tensor of the given shape filled by ``kaiming_uniform_``."""
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


def sample_vectors(samples, num: int):
    """Pick ``num`` rows of ``samples`` at random.

    Rows are drawn without replacement when ``samples`` has at least ``num``
    rows, and with replacement otherwise.
    """
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Run ``num_iters`` k-means iterations on ``samples`` of shape ``[n, d]``.

    Returns:
        A tuple ``(means, bins)``: the ``[num_clusters, d]`` centroids and the
        number of samples assigned to each centroid in the last iteration.
    """
    dim, dtype = samples.shape[-1], samples.dtype

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        # Pairwise differences, shape [n, c, d].
        diffs = rearrange(samples, "n d -> n () d") - rearrange(
            means, "c d -> () c d"
        )
        # Negated squared distance, so that the nearest centroid is the argmax.
        dists = -(diffs ** 2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # Avoid dividing by zero for clusters that received no sample.
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        # Empty clusters keep their previous centroid.
        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins


def preprocess(x):
    """Flatten all leading dimensions of ``x``, giving a ``[N, d]`` tensor."""
    x = rearrange(x, "... d -> (...) d")
    return x


def postprocess_emb(embed_ind, shape):
    """Reshape flat indices to ``shape`` without its last dimension."""
    return embed_ind.view(*shape[:-1])


class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        init_codebook: Optional numpy array of shape ``[codebook_size, dim]`` copied into the
            codebook at construction. When given, the codebook is marked as initialized and the
            k-means initialization is skipped.
        perform_expire (bool): Whether ``forward`` replaces dead codes while in training mode.
        perform_ema_update (bool): Whether ``forward`` updates the codebook with an exponential
            moving average while in training mode.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
        init_codebook=None,
        perform_expire: bool = False,
        perform_ema_update: bool = False,
    ):
        super().__init__()
        self.decay = decay
        # Random initialization is needed whenever k-means will not run; otherwise a
        # codebook built without `init_codebook` would be all zeros yet flagged as
        # initialized. With k-means the zeros are only a placeholder.
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
            uniform_init if not kmeans_init else torch.zeros
        )
        embed = init_fn(codebook_size, dim)
        if init_codebook is not None:
            # So that we don't need to care about mask issues in forward
            embed[:] = torch.from_numpy(init_codebook)  # the saved codebook is [V, D]
            logging.info(f"Initiated EuclideanCodebook from {init_codebook}")

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # `inited` is False only when k-means still has to run on the first batch.
        self.register_buffer(
            "inited",
            torch.Tensor([(not kmeans_init) or (init_codebook is not None)]),
        )
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

        self.training = True

        self.perform_expire = perform_expire
        self.perform_ema_update = perform_ema_update

    def init_embed_(self, data):
        """Initialize the codebook with k-means on ``data`` unless already initialized."""
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        distrib.broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        """Overwrite the codes selected by ``mask`` with random rows of ``samples``."""
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        """Replace codes whose EMA cluster size is below ``threshold_ema_dead_code``.

        Does nothing when the threshold is 0 or when no code is below it.
        """
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        distrib.broadcast_tensors(self.buffers())

    def quantize(self, x):
        """Return, for each row of the ``[N, d]`` input, the index of the nearest code."""
        embed = self.embed.t()
        # Negated squared Euclidean distance, expanded as -(|x|^2 - 2 x.e + |e|^2).
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def dequantize(self, embed_ind):
        """Look up the code vectors for ``embed_ind``."""
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        """Return code indices for ``x``; the result has shape ``x.shape[:-1]``."""
        shape = x.shape
        # pre-process
        x = preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        """Return the code vectors for ``embed_ind``."""
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        """Quantize ``x`` and, in training mode, optionally maintain the codebook.

        Returns:
            A tuple ``(quantize, embed_ind)``: the selected code vectors and their
            indices, the latter with shape ``x.shape[:-1]``.
        """
        shape, dtype = x.shape, x.dtype
        x = preprocess(x)

        self.init_embed_(x)

        flat_ind = self.quantize(x)
        embed_ind = postprocess_emb(flat_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            if self.perform_expire:
                self.expire_codes_(x)
            if self.perform_ema_update:
                # The [N, codebook_size] one-hot matrix is only needed for the EMA
                # statistics, so it is built here rather than on every call.
                embed_onehot = F.one_hot(flat_ind, self.codebook_size).type(dtype)
                ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
                embed_sum = x.t() @ embed_onehot
                ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
                cluster_size = (
                    laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                    * self.cluster_size.sum()
                )
                embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
                self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind


class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently, supports only euclidean distance.

    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
        init_codebook: Optional initial codebook forwarded to ``EuclideanCodebook``.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.,
        init_codebook=None,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        # Linear projections are only inserted when the codebook dimension differs.
        requires_projection = _codebook_dim != dim
        self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity())
        self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity())

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
            init_codebook=init_codebook,
        )
        self.codebook_size = codebook_size
        self.training = True

    @property
    def codebook(self):
        """The codebook tensor (the ``embed`` buffer of the underlying codebook)."""
        return self._codebook.embed

    def encode(self, x):
        """Return code indices for a ``[b, d, n]`` input."""
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        """Return the ``[b, d, n]`` quantized tensor for the given indices."""
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def forward(self, x):
        """Quantize ``x`` and compute the commitment loss.

        NOTE: unlike ``encode``/``decode``, no ``b d n <-> b n d`` rearranging is
        done here; the input is used directly as a [something, D] tensor and the
        output keeps that layout.

        Returns:
            A tuple ``(quantize, embed_ind, loss)``. ``loss`` is a scalar that stays
            0 outside training mode or when ``commitment_weight`` is not positive.
        """
        device = x.device
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x)

        if self.training:
            quantize = x + (quantize - x).detach()
            # NOTE: pass grad.

        loss = torch.tensor(0.0, device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        return quantize, embed_ind, loss


class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        )

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Quantize ``x`` with the first ``n_q`` layers (all layers when ``n_q`` is falsy).

        Returns:
            A tuple ``(quantized_out, out_indices, out_losses)``: the sum of the
            per-layer quantized outputs, and the per-layer indices and losses
            stacked along a new leading dimension.
        """
        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []

        n_q = n_q or len(self.layers)

        for layer in self.layers[:n_q]:
            quantized, indices, loss = layer(residual)
            # Each layer quantizes what the previous layers left over.
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)

        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, out_indices, out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Return the stacked per-layer indices for ``x`` using the first ``n_q`` layers."""
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Sum the decoded outputs of each layer for the stacked indices ``q_indices``."""
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out


class GroupVectorQuantization(nn.Module):
    """Group vector quantization: one independent quantizer per feature slice.

    Args:
        num_quantizers (int): Number of groups, one ``VectorQuantization`` each.
        init_codebook: Optional path passed to ``np.load``; the loaded array is
            indexed by group id to initialize each group's codebook ([G, V, D]).
            ``None`` or an empty value disables it.
        **kwargs: Forwarded to every ``VectorQuantization``.
    """

    def __init__(self, *, num_quantizers, init_codebook=None, **kwargs):
        super().__init__()
        if (init_codebook is not None) and (len(init_codebook) > 0):
            init_codebook = np.load(init_codebook)  # [G, V, D]
        else:
            init_codebook = None
        self.groups = nn.ModuleList(
            [
                VectorQuantization(
                    init_codebook=init_codebook[group_id] if init_codebook is not None else None,
                    **kwargs,
                )
                for group_id in range(num_quantizers)
            ]
        )
        self.n_group = num_quantizers

    def forward(self, x):
        """Quantize each feature slice of ``x`` with its own group.

        Args:
            x: Tensor assumed to have [something, D] shape, with D divisible by
                the number of groups.

        Returns:
            A tuple ``(quantized_result, quantized_indices, loss)``: a tensor shaped
            like ``x``, a ``[something, n_group]`` long tensor of indices, and the
            sum of the per-group losses.

        Raises:
            ValueError: If D is not divisible by the number of groups.
        """
        num, dim_in = x.shape
        if dim_in % self.n_group != 0:
            # Otherwise the trailing columns would silently be left at zero.
            raise ValueError(
                f"Input dimension {dim_in} is not divisible by the number of groups {self.n_group}"
            )
        dim_per_q = dim_in // self.n_group

        quantized_result = torch.zeros_like(x)
        quantized_indices = torch.zeros(num, self.n_group, dtype=torch.long, device=x.device)
        loss = torch.tensor(0.0, device=x.device)
        for group_i in range(self.n_group):
            dim_range = slice(group_i * dim_per_q, (group_i + 1) * dim_per_q)
            quantized_i, indices_i, loss_i = self.groups[group_i](x[:, dim_range])
            quantized_result[:, dim_range] = quantized_i
            quantized_indices[:, group_i] = indices_i
            loss = loss + loss_i
        return quantized_result, quantized_indices, loss

    def encode(self, x):
        """Return only the indices produced by ``forward`` for ``x``."""
        return self.forward(x)[1]  # only return indices

    def decode(self, q_indices):
        """Decode ``[B, G, L]`` indices and concatenate the groups along dim 1."""
        # Each entry is a tensor of shape [B, D, L].
        quantized_out = [
            self.groups[group_i].decode(q_indices[:, group_i, :])
            for group_i in range(q_indices.shape[1])
        ]
        return torch.concat(quantized_out, dim=1)