import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from einops import rearrange
from torch.nn.utils import weight_norm

def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))

def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))

def sample_vectors(samples, num):
    # samples: (N, D), num_samples: N, feature dim: D
    num_samples, device = samples.shape[0], samples.device
    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)
    return samples[indices].float()  # (num, D), ensure fp32

def kmeans(samples, num_clusters, num_iters=10):
    # samples: (N, D), N samples with D dimensions
    dim, dtype = samples.shape[-1], torch.float32  # Force fp32
    means = sample_vectors(samples, num_clusters).float()  # (num_clusters, D), ensure fp32
    
    for _ in range(num_iters):
        dists = -(samples.float().pow(2).sum(1, keepdim=True) -  # (N, 1), ensure fp32
                 2 * samples.float() @ means.t() +               # (N, num_clusters), ensure fp32
                 means.t().float().pow(2).sum(0, keepdim=True))  # (1, num_clusters), ensure fp32
        # dists: (N, num_clusters)
        buckets = dists.max(dim=-1).indices  # (N)
        bins = torch.bincount(buckets, minlength=num_clusters)  # (num_clusters)
        zero_mask = bins == 0  # (num_clusters)
        bins_min_clamped = bins.masked_fill(zero_mask, 1)  # (num_clusters)
        
        new_means = buckets.new_zeros(num_clusters, dim, dtype=torch.float32)  # (num_clusters, D), ensure fp32
        new_means.scatter_add_(0, buckets.unsqueeze(1).expand(-1, dim), samples.float())  # (num_clusters, D), ensure fp32
        new_means = new_means / bins_min_clamped[..., None]  # (num_clusters, D)
        means = torch.where(zero_mask[..., None], means, new_means)  # (num_clusters, D)
    
    # Final cluster assignments for returning cluster sizes
    dists = -(samples.float().pow(2).sum(1, keepdim=True) - 
             2 * samples.float() @ means.t() + 
             means.t().float().pow(2).sum(0, keepdim=True))  # (N, num_clusters), ensure fp32
    buckets = dists.max(dim=-1).indices  # (N)
    bins = torch.bincount(buckets, minlength=num_clusters).float()  # (num_clusters), ensure fp32
    
    return means, bins  # (num_clusters, D), (num_clusters)
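
# Hedged usage sketch (not part of the original file): illustrates the shapes produced by
# `kmeans` above. Wrapped in a function so nothing runs at import time; the helper name
# `_kmeans_shape_check` and the sizes are assumptions chosen purely for illustration.
def _kmeans_shape_check():
    samples = torch.randn(500, 8)                       # (N, D) feature vectors
    means, bins = kmeans(samples, num_clusters=16, num_iters=5)
    assert means.shape == (16, 8)                       # (num_clusters, D) centroids
    assert bins.shape == (16,)                          # (num_clusters,) final cluster sizes
    assert int(bins.sum().item()) == 500                # every sample is assigned exactly once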

class VectorQuantize(nn.Module):
    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
        commitment=1.0,
        decay=0.99,           # EMA decay
        epsilon=1e-5,         # Laplace smoothing epsilon
        threshold_ema_dead=2, # Dead code threshold
        kmeans_init=True,     # Use kmeans initialization
        kmeans_iters=10,      # Kmeans iterations
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.decay = decay
        self.epsilon = epsilon
        self.threshold_ema_dead = threshold_ema_dead
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        
        if self.input_dim != self.codebook_dim:
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)  # (B, D, T) -> (B, D', T)
            self.out_project = WNConv1d(self.codebook_dim, self.input_dim, kernel_size=1)  # (B, D', T) -> (B, D, T)
        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

        # Initialize codebook and EMA buffers
        init_fn = torch.zeros if kmeans_init else lambda x, y: torch.randn(x, y)
        self.register_buffer("codebook", init_fn(codebook_size, codebook_dim).float())  # (codebook_size, D'), ensure fp32
        self.register_buffer("inited", torch.tensor([not kmeans_init], dtype=torch.bool))  # (1)
        self.register_buffer("cluster_size", torch.zeros(codebook_size).float())  # (codebook_size), ensure fp32
        self.register_buffer("embed_avg", self.codebook.clone().float())  # (codebook_size, D'), ensure fp32

    def ema_update(self, encodings, embed_onehot):
        # encodings: (B*T, D'), embed_onehot: (B*T, codebook_size)
        """Update codebook using EMA"""
        encodings = encodings.float()  # Ensure fp32
        embed_onehot = embed_onehot.float()  # Ensure fp32
        cluster_size_new = embed_onehot.sum(0)  # (codebook_size)
        embed_sum = encodings.t() @ embed_onehot  # (D', codebook_size)
        
        # Distributed reduction
        if dist.is_initialized():
            dist.all_reduce(cluster_size_new, op=dist.ReduceOp.SUM)
            dist.all_reduce(embed_sum, op=dist.ReduceOp.SUM)
            
        ema_inplace(self.cluster_size, cluster_size_new, self.decay)  # (codebook_size)
        ema_inplace(self.embed_avg, embed_sum.t(), self.decay)  # (codebook_size, D')
        
        # Laplace smoothing
        cluster_size = (self.cluster_size + self.epsilon) / (self.cluster_size.sum() + self.codebook_size * self.epsilon)  # (codebook_size)
        cluster_size = cluster_size * self.cluster_size.sum()  # (codebook_size)
        self.codebook.copy_(self.embed_avg / cluster_size.unsqueeze(1))  # (codebook_size, D')

    def replace_dead_codes(self, encodings):
        # encodings: (B*T, D')
        """Replace dead codes with random samples from current batch"""
        if self.threshold_ema_dead == 0:
            return
        
        dead_mask = self.cluster_size < self.threshold_ema_dead  # (codebook_size)
        if dead_mask.any():
            # Sample replacements locally when not distributed, otherwise only on rank 0
            if not dist.is_initialized() or dist.get_rank() == 0:
                samples = sample_vectors(encodings.float(), self.codebook_size)  # (codebook_size, D'), ensure fp32
            else:
                samples = torch.zeros_like(self.codebook).float()  # Placeholder, ensure fp32
            
            # Broadcast samples
            if dist.is_initialized():
                dist.broadcast(samples, src=0)
            
            self.codebook[dead_mask] = samples[:dead_mask.sum()].to(self.codebook.dtype)  # Update dead codes

    def init_codebook(self, encodings):
        # encodings: (B*T, D')
        """Initialize codebook with k-means and update cluster_size"""
        if self.inited.item():
            return
        
        # Run k-means locally when not distributed, otherwise only on rank 0
        if not dist.is_initialized() or dist.get_rank() == 0:
            embed, cluster_sizes = kmeans(encodings.float(), self.codebook_size, self.kmeans_iters)  # (codebook_size, D'), (codebook_size), ensure fp32
        else:
            embed = torch.zeros(self.codebook_size, self.codebook_dim, device=encodings.device).float()  # Placeholder, ensure fp32
            cluster_sizes = torch.zeros(self.codebook_size, device=encodings.device, dtype=torch.float32)  # Placeholder, ensure fp32
        
        # Broadcast results
        if dist.is_initialized():
            dist.broadcast(embed, src=0)
            dist.broadcast(cluster_sizes, src=0)
            
        self.codebook.copy_(embed)  # (codebook_size, D')
        self.embed_avg.copy_(embed.clone())  # (codebook_size, D')
        self.cluster_size.copy_(cluster_sizes.float())  # (codebook_size)
        self.inited.fill_(True)

    def forward(self, z):  # z: (B, D, T)
        # logging.info(f"{self.cluster_size = }, {self.codebook = }, {self.embed_avg = }, {self.inited = }")
        z = z.float()  # Ensure fp32
        z_e = self.in_project(z).float()  # (B, D', T), ensure fp32
        
        # Rearrange for quantization
        encodings = rearrange(z_e, "b d t -> (b t) d").float()  # (B*T, D'), ensure fp32
        
        # Initialize codebook if needed
        if self.kmeans_init and not self.inited.item():
            self.init_codebook(encodings)

        # Quantization
        # Squared distances to codebook entries; named `distances` to avoid shadowing torch.distributed (`dist`)
        distances = (encodings.pow(2).sum(1, keepdim=True) -        # (B*T, 1)
            2 * encodings @ self.codebook.float().t() +             # (B*T, codebook_size)
            self.codebook.float().pow(2).sum(1, keepdim=True).t())  # (1, codebook_size)
        # distances: (B*T, codebook_size)
        
        indices = (-distances).max(1)[1]  # (B*T)
        indices = rearrange(indices, "(b t) -> b t", b=z.size(0))  # (B, T)
        
        # Get quantized vectors
        z_q = self.decode_code(indices).float()  # (B, D', T), ensure fp32
        
        # Commitment loss
        commit_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) * self.commitment  # (B)
        
        # EMA updates and dead code replacement during training
        if self.training and torch.is_grad_enabled():
            embed_onehot = F.one_hot(indices.view(-1), self.codebook_size).float()  # (B*T, codebook_size), ensure fp32
            self.ema_update(encodings, embed_onehot)
            self.replace_dead_codes(encodings)

        # Straight-through estimator
        z_q = z_e + (z_q - z_e).detach()  # (B, D', T)
        z_q = self.out_project(z_q).float()  # (B, D, T), ensure fp32

        return z_q, commit_loss, torch.tensor(0.0, device=z.device, dtype=torch.float32), indices, z  # (B, D, T), (B), scalar, (B, T), (B, D, T)

    def decode_code(self, embed_id):  # embed_id: (B, T)
        return F.embedding(embed_id, self.codebook).transpose(1, 2).float()  # (B, D', T), ensure fp32
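
# Hedged usage sketch (not in the original file): shows how a single VectorQuantize layer
# is driven on its own. The helper name `_vq_example` and all sizes are assumptions for
# illustration; calling the module in train mode exercises the k-means init and EMA paths.
def _vq_example():
    vq = VectorQuantize(input_dim=64, codebook_size=128, codebook_dim=8)
    vq.train()
    z = torch.randn(2, 64, 50)                       # (B, D, T)
    z_q, commit_loss, _, indices, _ = vq(z)
    assert z_q.shape == (2, 64, 50)                  # quantized output, projected back to D
    assert commit_loss.shape == (2,)                 # per-sample commitment loss
    assert indices.shape == (2, 50)                  # codebook index per frame
    # decode_code returns vectors in the codebook space (D' = 8), before out_project
    assert vq.decode_code(indices).shape == (2, 8, 50)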

class ResidualVQ(nn.Module):
    def __init__(
        self,
        input_dim: int = 1280, # Input dimension, unrelated to RVQ
        rvq_dim = None, # RVQ dimension. If different from input_dim/output_dim, will add input_dim->rvq_dim/rvq_dim->output_dim projection
        output_dim: int = None, # Output dimension, unrelated to RVQ
        num_quantizers: int = 32,
        codebook_size: int = 1024,
        codebook_dim: int = 8, # Dimension of each codebook. If different from rvq_dim, will add rvq_dim->codebook_dim and codebook_dim->rvq_dim projections
        quantizer_dropout: float = 0.5,
        decay=0.99,
        epsilon=1e-5,
        threshold_ema_dead=2,
        kmeans_init=True,
        kmeans_iters=10,
        skip_rvq_ratio: float = 0.0,  # New parameter: probability of skipping RVQ
        **kwargs,
    ):
        super().__init__()
        self.input_dim = input_dim
        
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer_dropout = quantizer_dropout
        self.skip_rvq_ratio = skip_rvq_ratio  # Store skip probability
        # Default rvq_dim / output_dim to input_dim when not given, so the 1x1 projections are well defined
        rvq_dim = rvq_dim if rvq_dim is not None else input_dim
        output_dim = output_dim if output_dim is not None else input_dim
        self.rvq_dim = rvq_dim
        
        self.input_proj = WNConv1d(input_dim, rvq_dim, kernel_size=1) if input_dim != rvq_dim else nn.Identity()
        self.output_proj = WNConv1d(rvq_dim, output_dim, kernel_size=1) if rvq_dim != output_dim else nn.Identity()

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(
                    input_dim=rvq_dim,
                    codebook_size=codebook_size,
                    codebook_dim=codebook_dim,
                    decay=decay,
                    epsilon=epsilon,
                    threshold_ema_dead=threshold_ema_dead,
                    kmeans_init=kmeans_init,
                    kmeans_iters=kmeans_iters,
                    **kwargs,
                )
                for _ in range(num_quantizers)
            ]
        )

    def forward(self, z, input_length, n_quantizers: int = None):  # z: (B, D, T), input_length: (B)
        z = self.input_proj(z)

        with torch.autocast('cuda', enabled = False):    
            batch_size, _, max_time = z.shape
            mask = torch.arange(max_time, device=z.device).expand(batch_size, max_time) < input_length.unsqueeze(1)  # (B, T)

            quantized_out = torch.zeros_like(z, dtype=torch.float32)  # (B, D, T), ensure fp32
            residual = z.clone().float()  # (B, D, T), ensure fp32

            all_commit_losses = []
            all_indices = []
            all_quantized = []

            n_quantizers = n_quantizers or self.num_quantizers

            # Randomly decide whether to skip RVQ during training
            skip_mask = None
            if self.training and torch.is_grad_enabled() and self.skip_rvq_ratio > 0:
                # Generate random mask with skip_rvq_ratio probability
                skip_mask = torch.rand(batch_size, device=z.device) < self.skip_rvq_ratio  # (B,)
                # If all samples are skipped, force the first sample to be unskipped
                if skip_mask.all():
                    skip_mask[0] = False  # Ensure at least one sample (index 0) is not skipped

            if self.training and torch.is_grad_enabled():
                n_quantizers_tensor = torch.ones((z.shape[0],), dtype=torch.float32, device=z.device) * self.num_quantizers + 1  # (B)
                dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],), dtype=torch.float32, device=z.device)  # (B)
                n_dropout = int(z.shape[0] * self.quantizer_dropout)
                n_quantizers_tensor[:n_dropout] = dropout[:n_dropout]  # (B)
            else:
                n_quantizers_tensor = torch.full((z.shape[0],), n_quantizers, dtype=torch.float32, device=z.device)  # (B)

            for i, quantizer in enumerate(self.quantizers):
                if not self.training and i >= n_quantizers:
                    break

                masked_residual = residual * mask.unsqueeze(1)  # (B, D, T)
                
                # If skipping RVQ, directly use input value
                if self.training and skip_mask is not None and skip_mask.any():
                    z_q_i = torch.zeros_like(masked_residual, dtype=torch.float32)  # (B, D, T), ensure fp32
                    commit_loss_i = torch.zeros(batch_size, device=z.device, dtype=torch.float32)  # (B), ensure fp32
                    indices_i = torch.zeros(batch_size, max_time, device=z.device, dtype=torch.long)  # (B, T)
                    z_e_i = torch.zeros_like(masked_residual, dtype=torch.float32)  # (B, D, T), ensure fp32
                    
                    # Quantize non-skipped samples
                    non_skipped_mask = ~skip_mask  # (B)
                    if non_skipped_mask.any():
                        z_q_i_non_skipped, commit_loss_i_non_skipped, _, indices_i_non_skipped, z_e_i_non_skipped = quantizer(
                            masked_residual[non_skipped_mask].float()  # Ensure fp32
                        )
                        z_q_i[non_skipped_mask] = z_q_i_non_skipped
                        commit_loss_i[non_skipped_mask] = commit_loss_i_non_skipped
                        indices_i[non_skipped_mask] = indices_i_non_skipped
                        z_e_i[non_skipped_mask] = z_e_i_non_skipped
                else:
                    z_q_i, commit_loss_i, _, indices_i, z_e_i = quantizer(masked_residual.float())  # (B, D, T), (B), scalar, (B, T), (B, D', T), ensure fp32

                quantizer_mask = (torch.full((z.shape[0],), i, device=z.device, dtype=torch.float32) < n_quantizers_tensor)  # (B)
                update_mask = (mask & quantizer_mask.unsqueeze(-1)).unsqueeze(1)  # (B, 1, T)
                
                # If skipping, output is directly the input
                if skip_mask is not None:
                    skip_mask_expanded = skip_mask.unsqueeze(1).unsqueeze(2)  # (B, 1, 1)
                    z_q_i = torch.where(skip_mask_expanded, masked_residual, z_q_i)  # (B, D, T)
                    commit_loss_i = torch.where(skip_mask, torch.zeros_like(commit_loss_i), commit_loss_i)  # (B)

                quantized_out = quantized_out + z_q_i * update_mask  # (B, D, T)

                # residual and z_q_i are already fp32 inside this autocast-disabled block
                residual = residual - z_q_i * update_mask  # (B, D, T)

                valid_mask = mask & quantizer_mask.unsqueeze(-1)  # (B, T)
                if valid_mask.any():
                    commit_loss_i = (commit_loss_i * quantizer_mask).sum() / quantizer_mask.sum()  # scalar
                else:
                    commit_loss_i = torch.tensor(0.0, device=z.device, dtype=torch.float32)  # scalar, ensure fp32

                all_commit_losses.append(commit_loss_i)  # scalar
                all_indices.append(indices_i)  # (B, T)
                all_quantized.append(z_q_i)  # (B, D, T)

            all_commit_losses = torch.stack(all_commit_losses)  # (N)
            all_indices = torch.stack(all_indices)  # (N, B, T)
            all_quantized = torch.stack(all_quantized)  # (N, B, D, T)

            output_length = input_length  # (B)

        quantized_out = self.output_proj(quantized_out)

        return (
            quantized_out,    # (B, D, T)
            all_indices,      # (N, B, T)
            all_commit_losses,# (N)
            all_quantized,    # (N, B, D, T)
            output_length,    # (B)
        )

    def decode_codes(self, codes):  # codes: (nq, B, T)
        """Decode codes from multiple quantizers to embeddings.
        
        Args:
            codes: Tensor of shape (nq, B, T) containing code indices for each quantizer.

        Returns:
            emb: Tensor of shape (B, D, T) representing the decoded embeddings.
        """
        nq, B, T = codes.shape
        device = codes.device
        emb = torch.zeros(B, self.rvq_dim, T, device=device, dtype=torch.float32)  # (B, D, T)

        for i, quantizer in enumerate(self.quantizers[:nq]):
            code_i = codes[i]  # (B, T)
            quantized_i = quantizer.out_project(quantizer.decode_code(code_i))  # (B, D, T), project codebook_dim back to rvq_dim
            emb += quantized_i  # Accumulate quantized embeddings

        emb = self.output_proj(emb)  # (B, D, T), apply output projection
        return emb # (B, D, T)


def ema_inplace(moving_avg, new, decay):
    # moving_avg: (codebook_size) or (codebook_size, D'), new: same as moving_avg
    """Update exponential moving average in-place"""
    moving_avg.data.mul_(decay).add_(new.float(), alpha=(1 - decay))  # ensure fp32
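

# Hedged smoke test (not part of the original module): a minimal end-to-end run of
# ResidualVQ on random features, assuming a CPU-only, non-distributed setup (the
# autocast context in forward() is disabled, so no CUDA should be needed). All sizes
# below are illustrative assumptions, not values used by any original training setup.
if __name__ == "__main__":
    torch.manual_seed(0)
    rvq = ResidualVQ(
        input_dim=1280,      # matches the constructor default
        rvq_dim=256,         # assumed projection dim for the demo
        output_dim=1280,
        num_quantizers=4,
        codebook_size=64,
        codebook_dim=8,
    )
    rvq.eval()
    z = torch.randn(2, 1280, 50)                 # (B, D, T) input features
    input_length = torch.tensor([50, 40])        # valid frames per batch item
    with torch.no_grad():
        z_q, codes, commit_losses, quantized, out_len = rvq(z, input_length)
    print(z_q.shape, codes.shape, commit_losses.shape, out_len)
    # Round-trip the codes back to embeddings
    with torch.no_grad():
        recon = rvq.decode_codes(codes)          # (B, output_dim, T)
    print(recon.shape)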