# Author: yhzx233
# Commit: ea174b0 ("feat: app.py")
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from einops import rearrange
from torch.nn.utils import weight_norm
def WNConv1d(*args, **kwargs):
    """Build an ``nn.Conv1d`` and wrap it with ``torch.nn.utils.weight_norm``.

    Every positional and keyword argument is forwarded unchanged to
    ``nn.Conv1d``.
    """
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
def WNConvTranspose1d(*args, **kwargs):
    """Build an ``nn.ConvTranspose1d`` and wrap it with ``torch.nn.utils.weight_norm``.

    Every positional and keyword argument is forwarded unchanged to
    ``nn.ConvTranspose1d``.
    """
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
def sample_vectors(samples, num):
    """Draw ``num`` rows from ``samples`` (N, D) and return them as float32.

    Rows are drawn without replacement (a random permutation prefix) when
    there are at least ``num`` rows, and uniformly with replacement otherwise.

    Returns:
        Tensor of shape (num, D), dtype float32, on ``samples.device``.
    """
    n_rows = samples.shape[0]
    dev = samples.device
    if n_rows < num:
        # Not enough rows to pick distinct ones: sample indices with replacement.
        picks = torch.randint(0, n_rows, (num,), device=dev)
    else:
        picks = torch.randperm(n_rows, device=dev)[:num]
    return samples[picks].float()
def _neg_sq_distances(samples, means):
    """Return the negated squared Euclidean distance of every sample to every mean.

    samples: (N, D), means: (K, D) -> (N, K). The nearest mean is the arg-max
    along the last dimension.
    """
    return -(
        samples.pow(2).sum(1, keepdim=True)  # (N, 1)
        - 2 * samples @ means.t()  # (N, K)
        + means.t().pow(2).sum(0, keepdim=True)  # (1, K)
    )


def kmeans(samples, num_clusters, num_iters=10):
    """Run Lloyd-style k-means on ``samples`` and return centroids and cluster sizes.

    Args:
        samples: (N, D) tensor; converted to float32 once up front.
        num_clusters: number of centroids K.
        num_iters: number of assign/update rounds.

    Returns:
        means: (K, D) float32 centroids, seeded with ``sample_vectors``. A
            centroid that receives no samples in a round keeps its previous value.
        bins: (K,) float32 count of samples assigned to each final centroid.
    """
    # Convert once instead of re-casting inside every distance computation.
    samples = samples.float()
    dim = samples.shape[-1]
    means = sample_vectors(samples, num_clusters)  # (K, D), float32

    for _ in range(num_iters):
        buckets = _neg_sq_distances(samples, means).max(dim=-1).indices  # (N)
        bins = torch.bincount(buckets, minlength=num_clusters)  # (K)
        zero_mask = bins == 0
        # Avoid division by zero for empty clusters; their mean is discarded below.
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = samples.new_zeros(num_clusters, dim)  # (K, D), float32
        new_means.scatter_add_(0, buckets.unsqueeze(1).expand(-1, dim), samples)
        new_means = new_means / bins_min_clamped[..., None]
        means = torch.where(zero_mask[..., None], means, new_means)

    # Final assignment so the returned sizes match the returned centroids.
    buckets = _neg_sq_distances(samples, means).max(dim=-1).indices  # (N)
    bins = torch.bincount(buckets, minlength=num_clusters).float()  # (K)
    return means, bins
class VectorQuantize(nn.Module):
    """Single-codebook vector quantizer whose codebook is maintained by EMA.

    Inputs of shape (B, input_dim, T) are projected to ``codebook_dim`` with a
    weight-normalised 1x1 convolution (identity when the dims match), snapped
    to the nearest codebook row by squared Euclidean distance, and projected
    back. The codebook is not trained through gradients. In training mode
    (with grad enabled) it is updated from exponential moving averages of
    assignment counts and assigned-vector sums with Laplace smoothing. Rows
    whose EMA count drops below ``threshold_ema_dead`` are re-seeded from the
    current batch. With ``kmeans_init`` the codebook is initialised by k-means
    on the first batch. The quantizer math is carried out in float32.
    """

    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
        commitment=1.0,
        decay=0.99,  # EMA decay
        epsilon=1e-5,  # Laplace smoothing epsilon
        threshold_ema_dead=2,  # Dead code threshold
        kmeans_init=True,  # Use kmeans initialization
        kmeans_iters=10,  # Kmeans iterations
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.decay = decay
        self.epsilon = epsilon
        self.threshold_ema_dead = threshold_ema_dead
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters

        if self.input_dim != self.codebook_dim:
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)  # (B, D, T) -> (B, D', T)
            self.out_project = WNConv1d(self.codebook_dim, self.input_dim, kernel_size=1)  # (B, D', T) -> (B, D, T)
        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

        # With k-means init the codebook starts at zeros and is filled on the first
        # forward pass; otherwise it starts from a standard-normal draw.
        init_fn = torch.zeros if kmeans_init else lambda x, y: torch.randn(x, y)
        self.register_buffer("codebook", init_fn(codebook_size, codebook_dim).float())  # (codebook_size, D')
        self.register_buffer("inited", torch.tensor([not kmeans_init], dtype=torch.bool))  # (1)
        self.register_buffer("cluster_size", torch.zeros(codebook_size).float())  # (codebook_size)
        self.register_buffer("embed_avg", self.codebook.clone().float())  # (codebook_size, D')

    @staticmethod
    def _dist_active():
        """Return True if torch.distributed is usable and a process group is initialised."""
        return dist.is_available() and dist.is_initialized()

    @classmethod
    def _is_primary_rank(cls):
        """Return True on the process that computes state shared across ranks.

        That is rank 0 of an initialised process group, or the only process when
        torch.distributed is not in use. Single-process runs must count as
        primary; otherwise k-means init and dead-code replacement would write the
        all-zero placeholders meant for non-zero ranks into the codebook.
        """
        return not cls._dist_active() or dist.get_rank() == 0

    @torch.no_grad()
    def ema_update(self, encodings, embed_onehot):
        """Update the codebook from EMA statistics of the current assignments.

        Args:
            encodings: (B*T, D') vectors that were quantized.
            embed_onehot: (B*T, codebook_size) one-hot code assignments.

        Runs under ``torch.no_grad`` so the buffers never record autograd
        history from ``encodings``.
        """
        encodings = encodings.float()
        embed_onehot = embed_onehot.float()
        cluster_size_new = embed_onehot.sum(0)  # (codebook_size)
        embed_sum = encodings.t() @ embed_onehot  # (D', codebook_size)
        # Sum the statistics over ranks so every replica applies the same update.
        if self._dist_active():
            dist.all_reduce(cluster_size_new, op=dist.ReduceOp.SUM)
            dist.all_reduce(embed_sum, op=dist.ReduceOp.SUM)
        ema_inplace(self.cluster_size, cluster_size_new, self.decay)  # (codebook_size)
        ema_inplace(self.embed_avg, embed_sum.t(), self.decay)  # (codebook_size, D')
        # Laplace smoothing: never divide by zero for empty clusters, while the
        # smoothed counts still sum to the total count n.
        n = self.cluster_size.sum()
        smoothed = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n
        self.codebook.copy_(self.embed_avg / smoothed.unsqueeze(1))  # (codebook_size, D')

    @torch.no_grad()
    def replace_dead_codes(self, encodings):
        """Overwrite codes whose EMA count is below ``threshold_ema_dead``.

        Replacement vectors are random rows of ``encodings`` (B*T, D'). They are
        drawn on the primary rank and broadcast so all ranks stay identical.
        Does nothing when ``threshold_ema_dead`` is 0 or no code is dead.
        """
        if self.threshold_ema_dead == 0:
            return
        dead_mask = self.cluster_size < self.threshold_ema_dead  # (codebook_size)
        if not dead_mask.any():
            return
        if self._is_primary_rank():
            samples = sample_vectors(encodings.float(), self.codebook_size)  # (codebook_size, D')
        else:
            samples = torch.zeros_like(self.codebook).float()  # receive buffer for broadcast
        if self._dist_active():
            dist.broadcast(samples, src=0)
        # NOTE(review): only `codebook` is overwritten; `embed_avg` and
        # `cluster_size` keep their old values, so the next `ema_update`
        # recomputes these rows from the stale averages. Confirm this is intended.
        self.codebook[dead_mask] = samples[:dead_mask.sum()].to(self.codebook.dtype)

    @torch.no_grad()
    def init_codebook(self, encodings):
        """Initialise the codebook and EMA buffers with k-means on ``encodings`` (B*T, D').

        k-means runs on the primary rank and its result is broadcast to the other
        ranks. Does nothing once ``inited`` is set.
        """
        if self.inited.item():
            return
        if self._is_primary_rank():
            embed, cluster_sizes = kmeans(encodings.float(), self.codebook_size, self.kmeans_iters)
        else:
            # Receive buffers for the broadcast below.
            embed = torch.zeros(self.codebook_size, self.codebook_dim, device=encodings.device, dtype=torch.float32)
            cluster_sizes = torch.zeros(self.codebook_size, device=encodings.device, dtype=torch.float32)
        if self._dist_active():
            dist.broadcast(embed, src=0)
            dist.broadcast(cluster_sizes, src=0)
        self.codebook.copy_(embed)  # (codebook_size, D')
        self.embed_avg.copy_(embed)  # (codebook_size, D')
        self.cluster_size.copy_(cluster_sizes.float())  # (codebook_size)
        self.inited.fill_(True)

    def forward(self, z):
        """Quantize ``z`` of shape (B, D, T).

        Returns:
            z_q: (B, D, T) float32 straight-through quantized output, after out_project.
            commit_loss: (B) commitment loss ``mse(z_e, sg(z_q)) * commitment``.
            A float32 zero scalar in the codebook-loss slot (the codebook is EMA-updated).
            indices: (B, T) selected code indices.
            z: the float32 input itself (callers rely on this shape, (B, D, T)).
        """
        z = z.float()
        z_e = self.in_project(z).float()  # (B, D', T)
        encodings = rearrange(z_e, "b d t -> (b t) d").float()  # (B*T, D')

        if self.kmeans_init and not self.inited.item():
            self.init_codebook(encodings)

        codebook = self.codebook.float()
        # Squared Euclidean distances, (B*T, codebook_size). Named `distances`
        # so it does not shadow the `dist` (torch.distributed) module alias.
        distances = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = (-distances).max(1)[1]  # (B*T)
        indices = rearrange(indices, "(b t) -> b t", b=z.size(0))  # (B, T)

        z_q = self.decode_code(indices).float()  # (B, D', T)
        commit_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) * self.commitment  # (B)

        if self.training and torch.is_grad_enabled():
            embed_onehot = F.one_hot(indices.view(-1), self.codebook_size).float()  # (B*T, codebook_size)
            self.ema_update(encodings, embed_onehot)
            self.replace_dead_codes(encodings)

        # Straight-through estimator: forward uses z_q, gradients flow to z_e.
        z_q = z_e + (z_q - z_e).detach()  # (B, D', T)
        z_q = self.out_project(z_q).float()  # (B, D, T)
        return z_q, commit_loss, torch.tensor(0.0, device=z.device, dtype=torch.float32), indices, z

    def decode_code(self, embed_id):
        """Look up codebook rows for ``embed_id`` (B, T) -> (B, D', T) float32.

        No output projection is applied here.
        """
        return F.embedding(embed_id, self.codebook).transpose(1, 2).float()
class ResidualVQ(nn.Module):
    """Residual vector quantizer: a stack of ``VectorQuantize`` stages.

    Stage i quantizes what stages 0..i-1 left unexplained; the sum of the
    stage outputs is the quantized signal. Optional 1x1 projections map
    input_dim -> rvq_dim before quantization and rvq_dim -> output_dim after.
    In training mode with grad enabled, a ``quantizer_dropout`` fraction of
    the batch uses a random number of stages. Each sample may also bypass
    quantization entirely with probability ``skip_rvq_ratio``.
    """

    def __init__(
        self,
        input_dim: int = 1280,  # Input dimension, unrelated to RVQ
        rvq_dim=None,  # RVQ dimension (defaults to input_dim). Projections are added where input_dim/output_dim differ from it
        output_dim: int = None,  # Output dimension (defaults to input_dim), unrelated to RVQ
        num_quantizers: int = 32,
        codebook_size: int = 1024,
        codebook_dim: int = 8,  # Per-codebook dimension; each stage adds rvq_dim<->codebook_dim projections when different
        quantizer_dropout: float = 0.5,
        decay=0.99,
        epsilon=1e-5,
        threshold_ema_dead=2,
        kmeans_init=True,
        kmeans_iters=10,
        skip_rvq_ratio: float = 0.0,  # Per-sample probability of bypassing quantization during training
        **kwargs,
    ):
        super().__init__()
        # With the old None defaults, building the projections crashed
        # (WNConv1d(..., None)); treat None as "same as input_dim".
        if rvq_dim is None:
            rvq_dim = input_dim
        if output_dim is None:
            output_dim = input_dim
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer_dropout = quantizer_dropout
        self.skip_rvq_ratio = skip_rvq_ratio
        self.rvq_dim = rvq_dim

        self.input_proj = WNConv1d(input_dim, rvq_dim, kernel_size=1) if input_dim != rvq_dim else nn.Identity()
        self.output_proj = WNConv1d(rvq_dim, output_dim, kernel_size=1) if rvq_dim != output_dim else nn.Identity()
        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(
                    input_dim=rvq_dim,
                    codebook_size=codebook_size,
                    codebook_dim=codebook_dim,
                    decay=decay,
                    epsilon=epsilon,
                    threshold_ema_dead=threshold_ema_dead,
                    kmeans_init=kmeans_init,
                    kmeans_iters=kmeans_iters,
                    **kwargs,
                )
                for _ in range(num_quantizers)
            ]
        )

    @staticmethod
    def _quantize_non_skipped(quantizer, masked_residual, skip_mask):
        """Run ``quantizer`` only on samples not marked in ``skip_mask`` (B).

        Skipped samples get zero outputs, zero loss and index 0.
        Returns (z_q (B, rvq_dim, T), commit_loss (B), indices (B, T)).
        """
        batch_size, _, max_time = masked_residual.shape
        device = masked_residual.device
        z_q_i = torch.zeros_like(masked_residual, dtype=torch.float32)
        commit_loss_i = torch.zeros(batch_size, device=device, dtype=torch.float32)
        indices_i = torch.zeros(batch_size, max_time, device=device, dtype=torch.long)
        keep = ~skip_mask
        if keep.any():
            z_q_kept, loss_kept, _, idx_kept, _ = quantizer(masked_residual[keep].float())
            z_q_i[keep] = z_q_kept
            commit_loss_i[keep] = loss_kept
            indices_i[keep] = idx_kept
        return z_q_i, commit_loss_i, indices_i

    def forward(self, z, input_length, n_quantizers: int = None):
        """Quantize ``z`` with the residual codebook stack.

        Args:
            z: (B, input_dim, T) features.
            input_length: (B) number of valid frames per sample. Frames at or
                beyond it are zeroed before quantization and left out of the
                output and of the residual update.
            n_quantizers: in eval mode, only the first ``n_quantizers`` stages
                run (default: all). In training mode every stage runs. With grad
                enabled, per-sample quantizer dropout then decides which stages
                contribute; with grad disabled, the first ``n_quantizers`` do.

        Returns:
            quantized_out: (B, output_dim, T) sum of the contributing stage outputs.
            all_indices: (N, B, T) code indices per stage.
            all_commit_losses: (N) commitment loss per stage, averaged over the
                samples for which that stage is active.
            all_quantized: (N, B, rvq_dim, T) per-stage outputs, not multiplied
                by the frame/dropout mask.
            output_length: ``input_length`` unchanged.
        """
        z = self.input_proj(z)

        with torch.autocast('cuda', enabled=False):
            batch_size, _, max_time = z.shape
            device = z.device
            mask = torch.arange(max_time, device=device).expand(batch_size, max_time) < input_length.unsqueeze(1)  # (B, T)
            quantized_out = torch.zeros_like(z, dtype=torch.float32)  # (B, rvq_dim, T)
            residual = z.clone().float()  # (B, rvq_dim, T)
            all_commit_losses, all_indices, all_quantized = [], [], []
            n_quantizers = n_quantizers or self.num_quantizers
            train_step = self.training and torch.is_grad_enabled()

            # Per-sample decision to bypass quantization (training only).
            skip_mask = None
            if train_step and self.skip_rvq_ratio > 0:
                skip_mask = torch.rand(batch_size, device=device) < self.skip_rvq_ratio  # (B)
                if skip_mask.all():
                    skip_mask[0] = False  # keep at least one sample quantized

            # Number of active stages per sample; stage i is active while i < n.
            if train_step:
                # num_quantizers + 1 means "all stages" for samples without dropout.
                n_quantizers_tensor = torch.ones((batch_size,), dtype=torch.float32, device=device) * self.num_quantizers + 1
                dropout = torch.randint(1, self.num_quantizers + 1, (batch_size,), dtype=torch.float32, device=device)
                n_dropout = int(batch_size * self.quantizer_dropout)
                n_quantizers_tensor[:n_dropout] = dropout[:n_dropout]
            else:
                n_quantizers_tensor = torch.full((batch_size,), n_quantizers, dtype=torch.float32, device=device)

            for i, quantizer in enumerate(self.quantizers):
                if not self.training and i >= n_quantizers:
                    break

                masked_residual = residual * mask.unsqueeze(1)  # (B, rvq_dim, T)
                if skip_mask is not None and skip_mask.any():
                    z_q_i, commit_loss_i, indices_i = self._quantize_non_skipped(quantizer, masked_residual, skip_mask)
                else:
                    z_q_i, commit_loss_i, _, indices_i, _ = quantizer(masked_residual.float())

                quantizer_mask = torch.full((batch_size,), i, device=device, dtype=torch.float32) < n_quantizers_tensor  # (B)
                valid_mask = mask & quantizer_mask.unsqueeze(-1)  # (B, T)
                update_mask = valid_mask.unsqueeze(1)  # (B, 1, T)

                # Skipped samples pass their residual through unquantized, at no loss.
                if skip_mask is not None:
                    z_q_i = torch.where(skip_mask[:, None, None], masked_residual, z_q_i)
                    commit_loss_i = torch.where(skip_mask, torch.zeros_like(commit_loss_i), commit_loss_i)

                quantized_out = quantized_out + z_q_i * update_mask
                residual = residual - z_q_i * update_mask  # both operands are float32

                if valid_mask.any():
                    commit_loss_i = (commit_loss_i * quantizer_mask).sum() / quantizer_mask.sum()
                else:
                    commit_loss_i = torch.tensor(0.0, device=device, dtype=torch.float32)

                all_commit_losses.append(commit_loss_i)
                all_indices.append(indices_i)
                all_quantized.append(z_q_i)

            all_commit_losses = torch.stack(all_commit_losses)  # (N)
            all_indices = torch.stack(all_indices)  # (N, B, T)
            all_quantized = torch.stack(all_quantized)  # (N, B, rvq_dim, T)
            output_length = input_length  # (B)

        quantized_out = self.output_proj(quantized_out)  # (B, output_dim, T)
        return (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_quantized,
            output_length,
        )

    def decode_codes(self, codes):
        """Decode per-stage code indices back to embeddings.

        Mirrors ``forward``. Each stage's codebook vectors go through that
        stage's ``out_project`` (codebook_dim -> rvq_dim) before being summed,
        and the sum then goes through ``output_proj``. Skipping ``out_project``
        breaks whenever codebook_dim != rvq_dim.

        Args:
            codes: (nq, B, T) code indices; only the first
                min(nq, num_quantizers) stages are used.

        Returns:
            (B, output_dim, T) decoded embeddings.
        """
        _, batch_size, max_time = codes.shape
        emb = torch.zeros(batch_size, self.rvq_dim, max_time, device=codes.device, dtype=torch.float32)
        # Same precision regime as the quantizer stages in forward().
        with torch.autocast('cuda', enabled=False):
            for code_i, quantizer in zip(codes, self.quantizers):
                stage_out = quantizer.out_project(quantizer.decode_code(code_i))  # (B, rvq_dim, T)
                emb = emb + stage_out.float()
        return self.output_proj(emb)
def ema_inplace(moving_avg, new, decay):
    """Blend ``new`` into ``moving_avg`` in place.

    Computes ``moving_avg = decay * moving_avg + (1 - decay) * new``. ``new``
    is cast to float32 first and must broadcast to ``moving_avg``'s shape. The
    update goes through ``.data``, so autograd does not record it.
    """
    weight_new = 1 - decay
    target = moving_avg.data
    target.mul_(decay)
    target.add_(new.float(), alpha=weight_new)