import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from einops import rearrange
from torch.nn.utils import weight_norm
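
# This module implements an EMA-updated vector quantizer (VectorQuantize) and a
# residual vector quantizer stack (ResidualVQ) with k-means codebook
# initialisation, dead-code replacement, quantizer dropout, and optional
# multi-process synchronisation via torch.distributed.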


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


def sample_vectors(samples, num):
    """Pick `num` rows from `samples` at random (with replacement if there are fewer than `num` rows)."""
    num_samples, device = samples.shape[0], samples.device
    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)
    return samples[indices].float()


def kmeans(samples, num_clusters, num_iters=10):
    """Run Lloyd's k-means in float32 and return (means, per-cluster counts)."""
    dim = samples.shape[-1]
    samples = samples.float()
    means = sample_vectors(samples, num_clusters).float()

    for _ in range(num_iters):
        # Negative squared distances via ||x||^2 - 2 x.c + ||c||^2, so argmax == nearest mean.
        dists = -(samples.pow(2).sum(1, keepdim=True) -
                  2 * samples @ means.t() +
                  means.t().pow(2).sum(0, keepdim=True))

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=torch.float32)
        new_means.scatter_add_(0, buckets.unsqueeze(1).expand(-1, dim), samples)
        new_means = new_means / bins_min_clamped[..., None]
        # Keep the previous mean for clusters that received no samples.
        means = torch.where(zero_mask[..., None], means, new_means)

    # Final assignment to report the cluster sizes for the returned means.
    dists = -(samples.pow(2).sum(1, keepdim=True) -
              2 * samples @ means.t() +
              means.t().pow(2).sum(0, keepdim=True))
    buckets = dists.max(dim=-1).indices
    bins = torch.bincount(buckets, minlength=num_clusters).float()

    return means, bins
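

# Shape conventions used below (names match the constructor arguments):
#   z        : (B, input_dim, T)       continuous features entering a quantizer
#   encodings: (B*T, codebook_dim)     flattened projected features
#   codebook : (codebook_size, codebook_dim)
#   indices  : (B, T)                  per-frame code ids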


class VectorQuantize(nn.Module):
    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
        commitment=1.0,
        decay=0.99,
        epsilon=1e-5,
        threshold_ema_dead=2,
        kmeans_init=True,
        kmeans_iters=10,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.decay = decay
        self.epsilon = epsilon
        self.threshold_ema_dead = threshold_ema_dead
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters

        # Project in and out of the (usually much smaller) codebook space when needed.
        if self.input_dim != self.codebook_dim:
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
            self.out_project = WNConv1d(self.codebook_dim, self.input_dim, kernel_size=1)
        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

        # With k-means init the codebook starts at zero and is filled on the first forward pass.
        init_fn = torch.zeros if kmeans_init else torch.randn
        self.register_buffer("codebook", init_fn(codebook_size, codebook_dim).float())
        self.register_buffer("inited", torch.tensor([not kmeans_init], dtype=torch.bool))
        self.register_buffer("cluster_size", torch.zeros(codebook_size).float())
        self.register_buffer("embed_avg", self.codebook.clone().float())

    def ema_update(self, encodings, embed_onehot):
        """Update the codebook with an exponential moving average of the assigned encodings."""
        encodings = encodings.float()
        embed_onehot = embed_onehot.float()
        cluster_size_new = embed_onehot.sum(0)          # (codebook_size,)
        embed_sum = encodings.t() @ embed_onehot        # (codebook_dim, codebook_size)

        # Aggregate the statistics across processes before the EMA step.
        if dist.is_initialized():
            dist.all_reduce(cluster_size_new, op=dist.ReduceOp.SUM)
            dist.all_reduce(embed_sum, op=dist.ReduceOp.SUM)

        ema_inplace(self.cluster_size, cluster_size_new, self.decay)
        ema_inplace(self.embed_avg, embed_sum.t(), self.decay)

        # Laplace-smooth the cluster sizes, then normalise the running sums into codes.
        cluster_size = (self.cluster_size + self.epsilon) / (self.cluster_size.sum() + self.codebook_size * self.epsilon)
        cluster_size = cluster_size * self.cluster_size.sum()
        self.codebook.copy_(self.embed_avg / cluster_size.unsqueeze(1))

    def replace_dead_codes(self, encodings):
        """Replace codes whose EMA cluster size fell below the threshold with random batch samples."""
        if self.threshold_ema_dead == 0:
            return

        dead_mask = self.cluster_size < self.threshold_ema_dead
        if dead_mask.any():
            # Sample replacements on rank 0 (or locally when not distributed) and broadcast them.
            if not dist.is_initialized() or dist.get_rank() == 0:
                samples = sample_vectors(encodings.float(), self.codebook_size)
            else:
                samples = torch.zeros_like(self.codebook).float()

            if dist.is_initialized():
                dist.broadcast(samples, src=0)

            self.codebook[dead_mask] = samples[:dead_mask.sum()].to(self.codebook.dtype)

    def init_codebook(self, encodings):
        """Initialise the codebook with k-means on the first batch and record the cluster sizes."""
        if self.inited.item():
            return

        # Run k-means on rank 0 (or locally when not distributed) and broadcast the result.
        if not dist.is_initialized() or dist.get_rank() == 0:
            embed, cluster_sizes = kmeans(encodings.float(), self.codebook_size, self.kmeans_iters)
        else:
            embed = torch.zeros(self.codebook_size, self.codebook_dim, device=encodings.device, dtype=torch.float32)
            cluster_sizes = torch.zeros(self.codebook_size, device=encodings.device, dtype=torch.float32)

        if dist.is_initialized():
            dist.broadcast(embed, src=0)
            dist.broadcast(cluster_sizes, src=0)

        self.codebook.copy_(embed)
        self.embed_avg.copy_(embed.clone())
        self.cluster_size.copy_(cluster_sizes.float())
        self.inited.fill_(True)

    def forward(self, z):
        z = z.float()
        z_e = self.in_project(z).float()

        # (B, D, T) -> (B*T, D) for the codebook lookup.
        encodings = rearrange(z_e, "b d t -> (b t) d").float()

        if self.kmeans_init and not self.inited.item():
            self.init_codebook(encodings)

        # Squared distances to every code; named `distances` so the torch.distributed
        # import `dist` is not shadowed.
        distances = (encodings.pow(2).sum(1, keepdim=True) -
                     2 * encodings @ self.codebook.float().t() +
                     self.codebook.float().pow(2).sum(1, keepdim=True).t())

        indices = (-distances).max(1)[1]
        indices = rearrange(indices, "(b t) -> b t", b=z.size(0))

        z_q = self.decode_code(indices).float()

        commit_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) * self.commitment

        if self.training and torch.is_grad_enabled():
            embed_onehot = F.one_hot(indices.view(-1), self.codebook_size).float()
            self.ema_update(encodings, embed_onehot)
            self.replace_dead_codes(encodings)

        # Straight-through estimator: gradients flow to z_e as if quantisation were the identity.
        z_q = z_e + (z_q - z_e).detach()
        z_q = self.out_project(z_q).float()

        # The zero tensor is a codebook-loss placeholder (the codebook is trained by EMA, not by gradient).
        return z_q, commit_loss, torch.tensor(0.0, device=z.device, dtype=torch.float32), indices, z

    def decode_code(self, embed_id):
        """Look up code vectors for `embed_id` and return them as (B, codebook_dim, T)."""
        return F.embedding(embed_id, self.codebook).transpose(1, 2).float()
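

# A single VectorQuantize stage maps (B, input_dim, T) -> (B, input_dim, T) and also
# returns a per-item commitment loss and (B, T) code indices.  ResidualVQ chains
# num_quantizers such stages, each one quantising the residual left by the previous
# stage, and sums their outputs.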


class ResidualVQ(nn.Module):
    def __init__(
        self,
        input_dim: int = 1280,
        rvq_dim: int = None,
        output_dim: int = None,
        num_quantizers: int = 32,
        codebook_size: int = 1024,
        codebook_dim: int = 8,
        quantizer_dropout: float = 0.5,
        decay=0.99,
        epsilon=1e-5,
        threshold_ema_dead=2,
        kmeans_init=True,
        kmeans_iters=10,
        skip_rvq_ratio: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer_dropout = quantizer_dropout
        self.skip_rvq_ratio = skip_rvq_ratio

        # Assumed fallback: treat a missing rvq_dim/output_dim as "same as the input dimension".
        rvq_dim = rvq_dim if rvq_dim is not None else input_dim
        output_dim = output_dim if output_dim is not None else input_dim
        self.rvq_dim = rvq_dim

        self.input_proj = WNConv1d(input_dim, rvq_dim, kernel_size=1) if input_dim != rvq_dim else nn.Identity()
        self.output_proj = WNConv1d(rvq_dim, output_dim, kernel_size=1) if rvq_dim != output_dim else nn.Identity()

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(
                    input_dim=rvq_dim,
                    codebook_size=codebook_size,
                    codebook_dim=codebook_dim,
                    decay=decay,
                    epsilon=epsilon,
                    threshold_ema_dead=threshold_ema_dead,
                    kmeans_init=kmeans_init,
                    kmeans_iters=kmeans_iters,
                    **kwargs,
                )
                for _ in range(num_quantizers)
            ]
        )
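
    # Training-time behaviour of forward():
    #   * quantizer_dropout: the first int(B * quantizer_dropout) items of the batch are
    #     quantised with a random number of stages (1..num_quantizers) instead of all of them.
    #   * skip_rvq_ratio: a random subset of the batch bypasses quantisation entirely and
    #     passes its residual straight through with zero commitment loss.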

    def forward(self, z, input_length, n_quantizers: int = None):
        z = self.input_proj(z)

        # Quantisation is done in float32 regardless of any surrounding autocast context.
        with torch.autocast("cuda", enabled=False):
            batch_size, _, max_time = z.shape
            # (B, T) mask of valid (non-padded) frames.
            mask = torch.arange(max_time, device=z.device).expand(batch_size, max_time) < input_length.unsqueeze(1)

            quantized_out = torch.zeros_like(z, dtype=torch.float32)
            residual = z.clone().float()

            all_commit_losses = []
            all_indices = []
            all_quantized = []

            n_quantizers = n_quantizers or self.num_quantizers

            # Optionally let some batch items skip quantisation entirely during training.
            skip_mask = None
            if self.training and torch.is_grad_enabled() and self.skip_rvq_ratio > 0:
                skip_mask = torch.rand(batch_size, device=z.device) < self.skip_rvq_ratio
                # Keep at least one item quantised.
                if skip_mask.all():
                    skip_mask[0] = False

            # Quantizer dropout: during training, the first n_dropout items of the batch
            # are quantised with a random number of stages.
            if self.training and torch.is_grad_enabled():
                n_quantizers_tensor = torch.ones((z.shape[0],), dtype=torch.float32, device=z.device) * self.num_quantizers + 1
                dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],), dtype=torch.float32, device=z.device)
                n_dropout = int(z.shape[0] * self.quantizer_dropout)
                n_quantizers_tensor[:n_dropout] = dropout[:n_dropout]
            else:
                n_quantizers_tensor = torch.full((z.shape[0],), n_quantizers, dtype=torch.float32, device=z.device)

            for i, quantizer in enumerate(self.quantizers):
                if not self.training and i >= n_quantizers:
                    break

                masked_residual = residual * mask.unsqueeze(1)

                if self.training and skip_mask is not None and skip_mask.any():
                    # Quantise only the non-skipped items; skipped items are filled in below.
                    z_q_i = torch.zeros_like(masked_residual, dtype=torch.float32)
                    commit_loss_i = torch.zeros(batch_size, device=z.device, dtype=torch.float32)
                    indices_i = torch.zeros(batch_size, max_time, device=z.device, dtype=torch.long)
                    z_e_i = torch.zeros_like(masked_residual, dtype=torch.float32)

                    non_skipped_mask = ~skip_mask
                    if non_skipped_mask.any():
                        z_q_i_non_skipped, commit_loss_i_non_skipped, _, indices_i_non_skipped, z_e_i_non_skipped = quantizer(
                            masked_residual[non_skipped_mask].float()
                        )
                        z_q_i[non_skipped_mask] = z_q_i_non_skipped
                        commit_loss_i[non_skipped_mask] = commit_loss_i_non_skipped
                        indices_i[non_skipped_mask] = indices_i_non_skipped
                        z_e_i[non_skipped_mask] = z_e_i_non_skipped
                else:
                    z_q_i, commit_loss_i, _, indices_i, z_e_i = quantizer(masked_residual.float())

                # Items whose sampled stage count has been reached contribute nothing from here on.
                quantizer_mask = (torch.full((z.shape[0],), i, device=z.device, dtype=torch.float32) < n_quantizers_tensor)
                update_mask = (mask & quantizer_mask.unsqueeze(-1)).unsqueeze(1)

                # Skipped items pass their residual through unquantised and incur no commitment loss.
                if skip_mask is not None:
                    skip_mask_expanded = skip_mask.unsqueeze(1).unsqueeze(2)
                    z_q_i = torch.where(skip_mask_expanded, masked_residual, z_q_i)
                    commit_loss_i = torch.where(skip_mask, torch.zeros_like(commit_loss_i), commit_loss_i)

                quantized_out = quantized_out + z_q_i * update_mask
                residual = residual - z_q_i * update_mask

                # Average the commitment loss over the items that actually used this stage.
                valid_mask = mask & quantizer_mask.unsqueeze(-1)
                if valid_mask.any():
                    commit_loss_i = (commit_loss_i * quantizer_mask).sum() / quantizer_mask.sum()
                else:
                    commit_loss_i = torch.tensor(0.0, device=z.device, dtype=torch.float32)

                all_commit_losses.append(commit_loss_i)
                all_indices.append(indices_i)
                all_quantized.append(z_q_i)

            all_commit_losses = torch.stack(all_commit_losses)
            all_indices = torch.stack(all_indices)
            all_quantized = torch.stack(all_quantized)

            output_length = input_length

            quantized_out = self.output_proj(quantized_out)

            return (
                quantized_out,
                all_indices,
                all_commit_losses,
                all_quantized,
                output_length,
            )

    def decode_codes(self, codes):
        """Decode codes from multiple quantizers to embeddings.

        Args:
            codes: Tensor of shape (nq, B, T) containing code indices for each quantizer.

        Returns:
            emb: Tensor of shape (B, D, T) representing the decoded embeddings.
        """
        nq, B, T = codes.shape
        device = codes.device
        emb = torch.zeros(B, self.rvq_dim, T, device=device, dtype=torch.float32)

        for i, quantizer in enumerate(self.quantizers[:nq]):
            code_i = codes[i]
            # Project each stage's code vectors back to rvq_dim before summing the residual contributions.
            quantized_i = quantizer.out_project(quantizer.decode_code(code_i))
            emb = emb + quantized_i

        emb = self.output_proj(emb)
        return emb


def ema_inplace(moving_avg, new, decay):
    """In-place exponential moving average: moving_avg <- decay * moving_avg + (1 - decay) * new."""
    moving_avg.data.mul_(decay).add_(new.float(), alpha=(1 - decay))
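

if __name__ == "__main__":
    # Minimal smoke-test sketch.  The sizes below are illustrative only and much smaller
    # than the class defaults; they are not the configuration of any real model.
    torch.manual_seed(0)
    rvq = ResidualVQ(
        input_dim=64,
        rvq_dim=32,
        output_dim=64,
        num_quantizers=4,
        codebook_size=16,
        codebook_dim=4,
        quantizer_dropout=0.0,
    )
    rvq.eval()  # inference path: no EMA updates, no quantizer dropout

    z = torch.randn(2, 64, 50)           # (batch, channels, time)
    lengths = torch.tensor([50, 30])     # number of valid frames per item

    with torch.no_grad():
        quantized, codes, commit_losses, _, out_len = rvq(z, lengths, n_quantizers=4)
        reconstructed = rvq.decode_codes(codes)

    print(quantized.shape)       # torch.Size([2, 64, 50])
    print(codes.shape)           # torch.Size([4, 2, 50])
    print(commit_losses.shape)   # torch.Size([4])
    print(reconstructed.shape)   # torch.Size([2, 64, 50])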