import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from einops import rearrange
from torch.nn.utils import weight_norm
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
def sample_vectors(samples, num):
# samples: (N, D), num_samples: N, feature dim: D
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = torch.randperm(num_samples, device=device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device=device)
return samples[indices].float() # (num, D), ensure fp32
def kmeans(samples, num_clusters, num_iters=10):
# samples: (N, D), N samples with D dimensions
dim, dtype = samples.shape[-1], torch.float32 # Force fp32
means = sample_vectors(samples, num_clusters).float() # (num_clusters, D), ensure fp32
for _ in range(num_iters):
dists = -(samples.float().pow(2).sum(1, keepdim=True) - # (N, 1), ensure fp32
2 * samples.float() @ means.t() + # (N, num_clusters), ensure fp32
means.t().float().pow(2).sum(0, keepdim=True)) # (1, num_clusters), ensure fp32
# dists: (N, num_clusters)
buckets = dists.max(dim=-1).indices # (N)
bins = torch.bincount(buckets, minlength=num_clusters) # (num_clusters)
zero_mask = bins == 0 # (num_clusters)
bins_min_clamped = bins.masked_fill(zero_mask, 1) # (num_clusters)
new_means = buckets.new_zeros(num_clusters, dim, dtype=torch.float32) # (num_clusters, D), ensure fp32
new_means.scatter_add_(0, buckets.unsqueeze(1).expand(-1, dim), samples.float()) # (num_clusters, D), ensure fp32
new_means = new_means / bins_min_clamped[..., None] # (num_clusters, D)
means = torch.where(zero_mask[..., None], means, new_means) # (num_clusters, D)
# Final cluster assignments for returning cluster sizes
dists = -(samples.float().pow(2).sum(1, keepdim=True) -
2 * samples.float() @ means.t() +
means.t().float().pow(2).sum(0, keepdim=True)) # (N, num_clusters), ensure fp32
buckets = dists.max(dim=-1).indices # (N)
bins = torch.bincount(buckets, minlength=num_clusters).float() # (num_clusters), ensure fp32
return means, bins # (num_clusters, D), (num_clusters)
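# Illustrative use of kmeans() (a sketch; the sizes below are assumed, not taken from any config):
#   feats = torch.randn(4096, 8)                        # (N, D) flattened encodings
#   means, bins = kmeans(feats, num_clusters=1024)      # (1024, 8) centroids, (1024,) cluster sizes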
class VectorQuantize(nn.Module):
def __init__(
self,
input_dim,
codebook_size,
codebook_dim,
commitment=1.0,
decay=0.99, # EMA decay
epsilon=1e-5, # Laplace smoothing epsilon
threshold_ema_dead=2, # Dead code threshold
kmeans_init=True, # Use kmeans initialization
kmeans_iters=10, # Kmeans iterations
):
super().__init__()
self.input_dim = input_dim
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.commitment = commitment
self.decay = decay
self.epsilon = epsilon
self.threshold_ema_dead = threshold_ema_dead
self.kmeans_init = kmeans_init
self.kmeans_iters = kmeans_iters
if self.input_dim != self.codebook_dim:
self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1) # (B, D, T) -> (B, D', T)
self.out_project = WNConv1d(self.codebook_dim, self.input_dim, kernel_size=1) # (B, D', T) -> (B, D, T)
else:
self.in_project = nn.Identity()
self.out_project = nn.Identity()
# Initialize codebook and EMA buffers
init_fn = torch.zeros if kmeans_init else lambda x, y: torch.randn(x, y)
self.register_buffer("codebook", init_fn(codebook_size, codebook_dim).float()) # (codebook_size, D'), ensure fp32
self.register_buffer("inited", torch.tensor([not kmeans_init], dtype=torch.bool)) # (1)
self.register_buffer("cluster_size", torch.zeros(codebook_size).float()) # (codebook_size), ensure fp32
self.register_buffer("embed_avg", self.codebook.clone().float()) # (codebook_size, D'), ensure fp32
def ema_update(self, encodings, embed_onehot):
# encodings: (B*T, D'), embed_onehot: (B*T, codebook_size)
"""Update codebook using EMA"""
encodings = encodings.float() # Ensure fp32
embed_onehot = embed_onehot.float() # Ensure fp32
cluster_size_new = embed_onehot.sum(0) # (codebook_size)
embed_sum = encodings.t() @ embed_onehot # (D', codebook_size)
# Distributed reduction
if dist.is_initialized():
dist.all_reduce(cluster_size_new, op=dist.ReduceOp.SUM)
dist.all_reduce(embed_sum, op=dist.ReduceOp.SUM)
ema_inplace(self.cluster_size, cluster_size_new, self.decay) # (codebook_size)
ema_inplace(self.embed_avg, embed_sum.t(), self.decay) # (codebook_size, D')
# Laplace smoothing
cluster_size = (self.cluster_size + self.epsilon) / (self.cluster_size.sum() + self.codebook_size * self.epsilon) # (codebook_size)
cluster_size = cluster_size * self.cluster_size.sum() # (codebook_size)
self.codebook.copy_(self.embed_avg / cluster_size.unsqueeze(1)) # (codebook_size, D')
def replace_dead_codes(self, encodings):
# encodings: (B*T, D')
"""Replace dead codes with random samples from current batch"""
if self.threshold_ema_dead == 0:
return
dead_mask = self.cluster_size < self.threshold_ema_dead # (codebook_size)
if dead_mask.any():
            # Sample replacement vectors on rank 0 (or locally when not running distributed)
            if not dist.is_initialized() or dist.get_rank() == 0:
                samples = sample_vectors(encodings.float(), self.codebook_size)  # (codebook_size, D'), ensure fp32
            else:
                samples = torch.zeros_like(self.codebook).float()  # Placeholder filled by broadcast, ensure fp32
# Broadcast samples
if dist.is_initialized():
dist.broadcast(samples, src=0)
self.codebook[dead_mask] = samples[:dead_mask.sum()].to(self.codebook.dtype) # Update dead codes
def init_codebook(self, encodings):
# encodings: (B*T, D')
"""Initialize codebook with k-means and update cluster_size"""
if self.inited.item():
return
        # Run k-means on rank 0 (or locally when not running distributed); other ranks receive the result via broadcast
        if not dist.is_initialized() or dist.get_rank() == 0:
            embed, cluster_sizes = kmeans(encodings.float(), self.codebook_size, self.kmeans_iters)  # (codebook_size, D'), (codebook_size), ensure fp32
        else:
            embed = torch.zeros(self.codebook_size, self.codebook_dim, device=encodings.device).float()  # ensure fp32
            cluster_sizes = torch.zeros(self.codebook_size, device=encodings.device, dtype=torch.float32)  # ensure fp32
# Broadcast results
if dist.is_initialized():
dist.broadcast(embed, src=0)
dist.broadcast(cluster_sizes, src=0)
self.codebook.copy_(embed) # (codebook_size, D')
self.embed_avg.copy_(embed.clone()) # (codebook_size, D')
self.cluster_size.copy_(cluster_sizes.float()) # (codebook_size)
self.inited.fill_(True)
def forward(self, z): # z: (B, D, T)
# logging.info(f"{self.cluster_size = }, {self.codebook = }, {self.embed_avg = }, {self.inited = }")
z = z.float() # Ensure fp32
z_e = self.in_project(z).float() # (B, D', T), ensure fp32
# Rearrange for quantization
encodings = rearrange(z_e, "b d t -> (b t) d").float() # (B*T, D'), ensure fp32
# Initialize codebook if needed
if self.kmeans_init and not self.inited.item():
self.init_codebook(encodings)
# Quantization
        # Squared distances to every codebook vector; named `dists` so the torch.distributed import (`dist`) is not shadowed
        dists = (encodings.pow(2).sum(1, keepdim=True) -  # (B*T, 1)
                 2 * encodings @ self.codebook.float().t() +  # (B*T, codebook_size)
                 self.codebook.float().pow(2).sum(1, keepdim=True).t())  # (1, codebook_size)
        # dists: (B*T, codebook_size)
        indices = (-dists).max(1)[1]  # (B*T)
indices = rearrange(indices, "(b t) -> b t", b=z.size(0)) # (B, T)
# Get quantized vectors
z_q = self.decode_code(indices).float() # (B, D', T), ensure fp32
# Commitment loss
commit_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) * self.commitment # (B)
# EMA updates and dead code replacement during training
if self.training and torch.is_grad_enabled():
embed_onehot = F.one_hot(indices.view(-1), self.codebook_size).float() # (B*T, codebook_size), ensure fp32
self.ema_update(encodings, embed_onehot)
self.replace_dead_codes(encodings)
# Straight-through estimator
z_q = z_e + (z_q - z_e).detach() # (B, D', T)
z_q = self.out_project(z_q).float() # (B, D, T), ensure fp32
        return z_q, commit_loss, torch.tensor(0.0, device=z.device, dtype=torch.float32), indices, z  # (B, D, T), (B), scalar, (B, T), (B, D, T) original input
def decode_code(self, embed_id): # embed_id: (B, T)
return F.embedding(embed_id, self.codebook).transpose(1, 2).float() # (B, D', T), ensure fp32
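# Illustrative use of VectorQuantize (a sketch; batch/time sizes are assumed for demonstration):
#   vq = VectorQuantize(input_dim=1280, codebook_size=1024, codebook_dim=8)
#   z = torch.randn(4, 1280, 50)                        # (B, D, T)
#   z_q, commit_loss, _, indices, z_in = vq(z)          # z_q: (B, D, T), commit_loss: (B,), indices: (B, T)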
class ResidualVQ(nn.Module):
def __init__(
self,
        input_dim: int = 1280,  # Dimension of the incoming features (independent of the internal RVQ dimension)
        rvq_dim=None,  # RVQ dimension; if it differs from input_dim/output_dim, input_dim->rvq_dim and rvq_dim->output_dim projections are added
        output_dim: int = None,  # Dimension of the returned features (independent of the internal RVQ dimension)
num_quantizers: int = 32,
codebook_size: int = 1024,
codebook_dim: int = 8, # Dimension of each codebook. If different from rvq_dim, will add rvq_dim->codebook_dim and codebook_dim->rvq_dim projections
quantizer_dropout: float = 0.5,
decay=0.99,
epsilon=1e-5,
threshold_ema_dead=2,
kmeans_init=True,
kmeans_iters=10,
        skip_rvq_ratio: float = 0.0,  # Probability of bypassing quantization entirely for a sample during training
**kwargs,
):
super().__init__()
self.input_dim = input_dim
self.num_quantizers = num_quantizers
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.quantizer_dropout = quantizer_dropout
self.skip_rvq_ratio = skip_rvq_ratio # Store skip probability
        # Fall back to input_dim when rvq_dim / output_dim are left as None, so the declared defaults are usable
        if rvq_dim is None:
            rvq_dim = input_dim
        if output_dim is None:
            output_dim = input_dim
        self.rvq_dim = rvq_dim
        self.input_proj = WNConv1d(input_dim, rvq_dim, kernel_size=1) if input_dim != rvq_dim else nn.Identity()
        self.output_proj = WNConv1d(rvq_dim, output_dim, kernel_size=1) if rvq_dim != output_dim else nn.Identity()
self.quantizers = nn.ModuleList(
[
VectorQuantize(
input_dim=rvq_dim,
codebook_size=codebook_size,
codebook_dim=codebook_dim,
decay=decay,
epsilon=epsilon,
threshold_ema_dead=threshold_ema_dead,
kmeans_init=kmeans_init,
kmeans_iters=kmeans_iters,
**kwargs,
)
for _ in range(num_quantizers)
]
)
def forward(self, z, input_length, n_quantizers: int = None): # z: (B, D, T), input_length: (B)
z = self.input_proj(z)
with torch.autocast('cuda', enabled = False):
batch_size, _, max_time = z.shape
mask = torch.arange(max_time, device=z.device).expand(batch_size, max_time) < input_length.unsqueeze(1) # (B, T)
quantized_out = torch.zeros_like(z, dtype=torch.float32) # (B, D, T), ensure fp32
residual = z.clone().float() # (B, D, T), ensure fp32
all_commit_losses = []
all_indices = []
all_quantized = []
n_quantizers = n_quantizers or self.num_quantizers
# Randomly decide whether to skip RVQ during training
skip_mask = None
if self.training and torch.is_grad_enabled() and self.skip_rvq_ratio > 0:
# Generate random mask with skip_rvq_ratio probability
skip_mask = torch.rand(batch_size, device=z.device) < self.skip_rvq_ratio # (B,)
# If all samples are skipped, force the first sample to be unskipped
if skip_mask.all():
skip_mask[0] = False # Ensure at least one sample (index 0) is not skipped
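            # Quantizer dropout: a random fraction (quantizer_dropout) of the batch trains with a random number of quantizers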
if self.training and torch.is_grad_enabled():
n_quantizers_tensor = torch.ones((z.shape[0],), dtype=torch.float32, device=z.device) * self.num_quantizers + 1 # (B)
dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],), dtype=torch.float32, device=z.device) # (B)
n_dropout = int(z.shape[0] * self.quantizer_dropout)
n_quantizers_tensor[:n_dropout] = dropout[:n_dropout] # (B)
else:
n_quantizers_tensor = torch.full((z.shape[0],), n_quantizers, dtype=torch.float32, device=z.device) # (B)
for i, quantizer in enumerate(self.quantizers):
if not self.training and i >= n_quantizers:
break
masked_residual = residual * mask.unsqueeze(1) # (B, D, T)
# If skipping RVQ, directly use input value
if self.training and skip_mask is not None and skip_mask.any():
z_q_i = torch.zeros_like(masked_residual, dtype=torch.float32) # (B, D, T), ensure fp32
commit_loss_i = torch.zeros(batch_size, device=z.device, dtype=torch.float32) # (B), ensure fp32
indices_i = torch.zeros(batch_size, max_time, device=z.device, dtype=torch.long) # (B, T)
z_e_i = torch.zeros_like(masked_residual, dtype=torch.float32) # (B, D, T), ensure fp32
# Quantize non-skipped samples
non_skipped_mask = ~skip_mask # (B)
if non_skipped_mask.any():
z_q_i_non_skipped, commit_loss_i_non_skipped, _, indices_i_non_skipped, z_e_i_non_skipped = quantizer(
masked_residual[non_skipped_mask].float() # Ensure fp32
)
z_q_i[non_skipped_mask] = z_q_i_non_skipped
commit_loss_i[non_skipped_mask] = commit_loss_i_non_skipped
indices_i[non_skipped_mask] = indices_i_non_skipped
z_e_i[non_skipped_mask] = z_e_i_non_skipped
else:
z_q_i, commit_loss_i, _, indices_i, z_e_i = quantizer(masked_residual.float()) # (B, D, T), (B), scalar, (B, T), (B, D', T), ensure fp32
quantizer_mask = (torch.full((z.shape[0],), i, device=z.device, dtype=torch.float32) < n_quantizers_tensor) # (B)
update_mask = (mask & quantizer_mask.unsqueeze(-1)).unsqueeze(1) # (B, 1, T)
# If skipping, output is directly the input
if skip_mask is not None:
skip_mask_expanded = skip_mask.unsqueeze(1).unsqueeze(2) # (B, 1, 1)
z_q_i = torch.where(skip_mask_expanded, masked_residual, z_q_i) # (B, D, T)
commit_loss_i = torch.where(skip_mask, torch.zeros_like(commit_loss_i), commit_loss_i) # (B)
quantized_out = quantized_out + z_q_i * update_mask # (B, D, T)
residual_fp32 = residual.to(dtype=torch.float32) # (B, D, T)
z_q_i_fp32 = z_q_i.to(dtype=torch.float32) # (B, D, T)
residual_fp32 = residual_fp32 - z_q_i_fp32 * update_mask # (B, D, T)
residual = residual_fp32.to(dtype=torch.float32) # (B, D, T), ensure fp32
valid_mask = mask & quantizer_mask.unsqueeze(-1) # (B, T)
if valid_mask.any():
commit_loss_i = (commit_loss_i * quantizer_mask).sum() / quantizer_mask.sum() # scalar
else:
commit_loss_i = torch.tensor(0.0, device=z.device, dtype=torch.float32) # scalar, ensure fp32
all_commit_losses.append(commit_loss_i) # scalar
all_indices.append(indices_i) # (B, T)
all_quantized.append(z_q_i) # (B, D, T)
all_commit_losses = torch.stack(all_commit_losses) # (N)
all_indices = torch.stack(all_indices) # (N, B, T)
all_quantized = torch.stack(all_quantized) # (N, B, D, T)
output_length = input_length # (B)
quantized_out = self.output_proj(quantized_out)
return (
quantized_out, # (B, D, T)
all_indices, # (N, B, T)
all_commit_losses,# (N)
all_quantized, # (N, B, D, T)
output_length, # (B)
)
def decode_codes(self, codes): # codes: (nq, B, T)
"""Decode codes from multiple quantizers to embeddings.
Args:
codes: Tensor of shape (nq, B, T) containing code indices for each quantizer.
Returns:
emb: Tensor of shape (B, D, T) representing the decoded embeddings.
"""
nq, B, T = codes.shape
device = codes.device
emb = torch.zeros(B, self.rvq_dim, T, device=device, dtype=torch.float32) # (B, D, T)
for i, quantizer in enumerate(self.quantizers[:nq]):
code_i = codes[i] # (B, T)
            quantized_i = quantizer.out_project(quantizer.decode_code(code_i))  # (B, D', T) -> (B, D, T); out_project is needed whenever codebook_dim != rvq_dim
            emb += quantized_i  # Accumulate per-quantizer contributions
emb = self.output_proj(emb) # (B, D, T), apply output projection
return emb # (B, D, T)
def ema_inplace(moving_avg, new, decay):
# moving_avg: (codebook_size) or (codebook_size, D'), new: same as moving_avg
"""Update exponential moving average in-place"""
    moving_avg.data.mul_(decay).add_(new.float(), alpha=(1 - decay))  # ensure fp32
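# Minimal smoke test (illustrative only, not part of the original module). The dimensions below
# (rvq_dim=256, num_quantizers=8, T=50, ...) are assumed purely for demonstration; k-means init is
# disabled so the example runs on a small random batch without a warm-up pass.
if __name__ == "__main__":
    torch.manual_seed(0)
    rvq = ResidualVQ(
        input_dim=1280,
        rvq_dim=256,
        output_dim=1280,
        num_quantizers=8,
        codebook_size=1024,
        codebook_dim=8,
        kmeans_init=False,
    )
    rvq.train()
    z = torch.randn(2, 1280, 50)                       # (B, D, T)
    input_length = torch.tensor([50, 30])              # valid frames per batch item
    quantized, codes, commit_losses, quantized_per_q, out_len = rvq(z, input_length)
    print(quantized.shape, codes.shape, commit_losses.shape)  # (2, 1280, 50), (8, 2, 50), (8,)
    # Reconstruct embeddings from the discrete codes alone
    recon = rvq.decode_codes(codes)                    # (B, output_dim, T)
    print(recon.shape)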