|
"""Vector quantizer. |
|
|
|
Copyright (2024) Bytedance Ltd. and/or its affiliates |
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); |
|
you may not use this file except in compliance with the License. |
|
You may obtain a copy of the License at |
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
Unless required by applicable law or agreed to in writing, software |
|
distributed under the License is distributed on an "AS IS" BASIS, |
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
See the License for the specific language governing permissions and |
|
limitations under the License. |
|
|
|
Reference: |
|
https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py |
|
https://github.com/google-research/magvit/blob/main/videogvt/models/vqvae.py |
|
https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/distributions/distributions.py |
|
https://github.com/lyndonzheng/CVQ-VAE/blob/main/quantise.py |
|
""" |
|
|
|
from typing import Mapping, Text, Tuple |
|
|
|
import torch |
|
from einops import rearrange |
|
from accelerate.utils.operations import gather |
|
from torch.cuda.amp import autocast |
|
|
|
|
|
class VectorQuantizer(torch.nn.Module): |
|
def __init__( |
|
self, |
|
codebook_size: int = 1024, |
|
token_size: int = 256, |
|
commitment_cost: float = 0.25, |
|
use_l2_norm: bool = False, |
|
clustering_vq: bool = False, |
|
): |
|
super().__init__() |
|
self.codebook_size = codebook_size |
|
self.token_size = token_size |
|
self.commitment_cost = commitment_cost |
|
|
|
self.embedding = torch.nn.Embedding(codebook_size, token_size) |
|
self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size) |
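        # Uniform initialization in [-1/codebook_size, 1/codebook_size], as in the
        # taming-transformers reference implementation.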
|
self.use_l2_norm = use_l2_norm |
|
|
|
self.clustering_vq = clustering_vq |
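        # When clustering is enabled, track an exponential moving average of how often
        # each code is used so that rarely used codes can be re-anchored to encoder
        # features during training (cf. the CVQ-VAE reference above).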
|
if clustering_vq: |
|
self.decay = 0.99 |
|
self.register_buffer("embed_prob", torch.zeros(self.codebook_size)) |
|
|
|
|
|
@autocast(enabled=False) |
|
def forward( |
|
self, z: torch.Tensor |
|
) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: |
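        """Quantizes a feature map `z` of shape [B, C, H, W].

        Returns the quantized tensor (same shape as `z`, with straight-through
        gradients) together with a dict holding the quantizer, commitment, and
        codebook losses and the selected codebook indices of shape [B, H, W].
        """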
|
z = z.float() |
|
z = rearrange(z, "b c h w -> b h w c").contiguous() |
|
z_flattened = rearrange(z, "b h w c -> (b h w) c") |
|
unnormed_z_flattened = z_flattened |
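        # Keep the un-normalized features around: the clustering update below re-anchors
        # codes in the original (un-normalized) feature space.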
|
|
|
if self.use_l2_norm: |
|
z_flattened = torch.nn.functional.normalize(z_flattened, dim=-1) |
|
embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1) |
|
else: |
|
embedding = self.embedding.weight |
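
        # Pairwise squared Euclidean distances between tokens and codebook entries,
        # expanded as ||z||^2 + ||e||^2 - 2 * z . e to avoid materializing the full
        # difference tensor.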
|
d = ( |
|
torch.sum(z_flattened**2, dim=1, keepdim=True) |
|
+ torch.sum(embedding**2, dim=1) |
|
- 2 * torch.einsum("bd,dn->bn", z_flattened, embedding.T) |
|
) |
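
        # Nearest-neighbour assignment: each token takes its closest codebook entry.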
|
|
|
min_encoding_indices = torch.argmin(d, dim=1) |
|
z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape) |
|
|
|
if self.use_l2_norm: |
|
z = torch.nn.functional.normalize(z, dim=-1) |
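
        # Standard VQ-VAE objective: the commitment term (scaled by commitment_cost)
        # pulls encoder outputs toward their assigned codes, while the codebook term
        # moves the codes toward the detached encoder outputs.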
|
|
|
|
|
commitment_loss = self.commitment_cost * torch.mean( |
|
(z_quantized.detach() - z) ** 2 |
|
) |
|
codebook_loss = torch.mean((z_quantized - z.detach()) ** 2) |
|
|
|
if self.clustering_vq and self.training: |
|
with torch.no_grad(): |
|
|
|
encoding_indices = gather(min_encoding_indices) |
|
                if len(min_encoding_indices.shape) != 1:

                    raise ValueError(

                        f"min_encoding_indices has an unexpected shape: {min_encoding_indices.shape}"

                    )
|
|
|
encodings = torch.zeros( |
|
encoding_indices.shape[0], self.codebook_size, device=z.device |
|
) |
|
encodings.scatter_(1, encoding_indices.unsqueeze(1), 1) |
|
avg_probs = torch.mean(encodings, dim=0) |
|
self.embed_prob.mul_(self.decay).add_(avg_probs, alpha=1 - self.decay) |
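                # embed_prob is an exponential moving average of per-code usage frequency.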
|
|
|
all_d = gather(d) |
|
all_unnormed_z_flattened = gather(unnormed_z_flattened).detach() |
|
                if all_d.shape[0] != all_unnormed_z_flattened.shape[0]:

                    raise ValueError(

                        "all_d and all_unnormed_z_flattened have different lengths: "

                        + f"{all_d.shape}, {all_unnormed_z_flattened.shape}"

                    )
|
indices = torch.argmin(all_d, dim=0) |
|
random_feat = all_unnormed_z_flattened[indices] |
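
                # Per-code interpolation weight: rarely used codes (low embed_prob) get a
                # weight near 1 and are pulled toward their nearest encoder feature, while
                # frequently used codes are left essentially unchanged.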
|
|
|
decay = ( |
|
torch.exp( |
|
-(self.embed_prob * self.codebook_size * 10) / (1 - self.decay) |
|
- 1e-3 |
|
) |
|
.unsqueeze(1) |
|
.repeat(1, self.token_size) |
|
) |
|
self.embedding.weight.data = ( |
|
self.embedding.weight.data * (1 - decay) + random_feat * decay |
|
) |
|
|
|
loss = commitment_loss + codebook_loss |
|
|
|
|
|
z_quantized = z + (z_quantized - z).detach() |
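        # Straight-through estimator: the forward value is z_quantized, while gradients
        # flow back to z, bypassing the non-differentiable argmin.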
|
|
|
|
|
z_quantized = rearrange(z_quantized, "b h w c -> b c h w").contiguous() |
|
|
|
result_dict = dict( |
|
quantizer_loss=loss, |
|
commitment_loss=commitment_loss, |
|
codebook_loss=codebook_loss, |
|
min_encoding_indices=min_encoding_indices.view( |
|
z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3] |
|
), |
|
) |
|
|
|
return z_quantized, result_dict |
|
|
|
def get_codebook_entry(self, indices): |
|
if len(indices.shape) == 1: |
|
z_quantized = self.embedding(indices) |
|
elif len(indices.shape) == 2: |
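            # 2-D input is treated as (soft) one-hot codes over the codebook and mapped
            # to embeddings by matrix multiplication.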
|
z_quantized = torch.einsum("bd,dn->bn", indices, self.embedding.weight) |
|
else: |
|
raise NotImplementedError |
|
if self.use_l2_norm: |
|
z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1) |
|
return z_quantized |
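

# A minimal usage sketch for VectorQuantizer (illustrative only; the random tensor
# below stands in for encoder features and is not part of this module):
#
#   quantizer = VectorQuantizer(codebook_size=1024, token_size=256)
#   z = torch.randn(2, 256, 16, 16)            # [B, C, H, W] encoder output
#   z_q, results = quantizer(z)                # z_q has the same shape as z
#   tokens = results["min_encoding_indices"]   # [B, H, W] codebook indices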
|
|
|
|
|
class DiagonalGaussianDistribution(object): |
|
@autocast(enabled=False) |
|
def __init__(self, parameters, deterministic=False): |
|
"""Initializes a Gaussian distribution instance given the parameters. |
|
|
|
Args: |
|
            parameters (torch.Tensor): The parameters of the Gaussian distribution. It is expected

                to be of shape [B, 2 * C, *], where B is the batch size and C is the embedding dimension.

                The first C channels are used for the mean and the last C channels for the logvar.

            deterministic (bool): Whether to use deterministic sampling. When True, the sampled

                value is simply the mean (i.e., std = 0).
|
""" |
|
self.parameters = parameters |
|
self.mean, self.logvar = torch.chunk(parameters.float(), 2, dim=1) |
|
self.logvar = torch.clamp(self.logvar, -30.0, 20.0) |
|
self.deterministic = deterministic |
|
self.std = torch.exp(0.5 * self.logvar) |
|
self.var = torch.exp(self.logvar) |
|
if self.deterministic: |
|
self.var = self.std = torch.zeros_like(self.mean).to( |
|
device=self.parameters.device |
|
) |
|
|
|
@autocast(enabled=False) |
|
def sample(self): |
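        # Reparameterized sampling: mean + std * eps with eps drawn from N(0, I).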
|
x = self.mean.float() + self.std.float() * torch.randn(self.mean.shape).to( |
|
device=self.parameters.device |
|
) |
|
return x |
|
|
|
@autocast(enabled=False) |
|
def mode(self): |
|
return self.mean |
|
|
|
@autocast(enabled=False) |
|
def kl(self): |
|
if self.deterministic: |
|
return torch.Tensor([0.0]) |
|
else: |
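            # Closed-form KL divergence to the standard normal N(0, I):
            #   0.5 * sum(mean^2 + var - 1 - logvar) over the non-batch dimensions.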
|
return 0.5 * torch.sum( |
|
torch.pow(self.mean.float(), 2) |
|
+ self.var.float() |
|
- 1.0 |
|
- self.logvar.float(), |
|
dim=[1, 2], |
|
) |
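

# An illustrative sketch for DiagonalGaussianDistribution (assumed [B, 2 * C, N] layout,
# matching the kl() reduction over dims [1, 2]; not part of the public API):
#
#   moments = torch.randn(2, 2 * 16, 32)
#   posterior = DiagonalGaussianDistribution(moments)
#   latent = posterior.sample()                # [B, C, N]
#   kl_loss = posterior.kl().mean()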
|
|