|
""" |
|
Code is from https://github.com/TencentARC/SEED-Voken. Thanks! |
|
|
|
Lookup Free Quantization |
|
Proposed in https://arxiv.org/abs/2310.05737 |
|
|
|
In the simplest setup, each dimension is quantized into {-1, 1}. |
|
An entropy penalty is used to encourage utilization. |
|
|
|
Refer to |
|
https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/lookup_free_quantization.py |
|
https://github.com/theAdamColton/ijepa-enhanced/blob/7edef5f7288ae8f537f0db8a10044a2a487f70c9/ijepa_enhanced/lfq.py |
|
""" |
|
|
|
from collections import namedtuple |
|
from math import ceil, log2 |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from einops import pack, rearrange, reduce, unpack |
|
from torch import einsum |
|
from torch.nn import Module |
|
|
|
|
|
|
|
# Diagnostic terms that LFQ.forward returns next to its main outputs when it
# is called with return_loss_breakdown=True.
LossBreakdown = namedtuple(
    typename="LossBreakdown",
    field_names=(
        "per_sample_entropy",  # first value returned by entropy_loss
        "codebook_entropy",  # second value returned by entropy_loss
        "commitment",  # MSE between the input and its quantized version
        "avg_probs",  # LFQ.forward always fills this with its zero buffer
    ),
)
|
|
|
|
|
|
|
|
|
def exists(v):
    """Tell whether ``v`` holds a value, i.e. is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    if v is None:
        return False
    return True
|
|
|
|
|
def default(*args):
    """Resolve the first argument that is not ``None``.

    Candidates are inspected left to right. A callable candidate is treated
    as a lazy factory and is invoked with no arguments; anything else is
    returned unchanged. When every candidate is ``None`` (or none are given)
    the result is ``None``.
    """
    for candidate in args:
        if candidate is None:
            continue
        if callable(candidate):
            return candidate()
        return candidate
    return None
|
|
|
|
|
def pack_one(t, pattern):
    """Pack the single tensor ``t`` with einops ``pack``.

    Returns the ``(packed, packed_shapes)`` pair produced by ``pack``; the
    second element is what ``unpack_one`` needs to undo the operation.
    """
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes
|
|
|
|
|
def unpack_one(t, ps, pattern):
    """Undo ``pack_one``: unpack ``t`` with einops ``unpack`` and return the
    first (normally the only) tensor of the result.
    """
    pieces = unpack(t, ps, pattern)
    return pieces[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def entropy(prob):
    """Entropy (natural log) of ``prob`` taken over its last dimension.

    ``1e-5`` is added inside the logarithm so that zero probabilities do not
    yield ``-inf``.
    """
    weighted_log = prob * torch.log(prob + 1e-5)
    return -weighted_log.sum(dim=-1)
|
|
|
|
|
|
|
|
|
|
|
def mult_along_first_dims(x, y):
    """Multiply ``x`` by ``y`` where ``y`` lines up with the leading dims of ``x``.

    ``y`` receives one trailing singleton axis for every dimension that ``x``
    has beyond it, so ordinary broadcasting applies it across the remaining
    trailing dimensions of ``x``.
    """
    missing = x.ndim - y.ndim
    if missing > 0:
        # Indexing with `None` appends size-1 axes after the existing ones.
        y = y[(...,) + (None,) * missing]
    return x * y
|
|
|
|
|
def masked_mean(x, m):
    """Mean of the entries of ``x`` selected by ``m``.

    ``m`` covers the leading dimensions of ``x``; the reduction runs over
    exactly those dimensions, so the result keeps the trailing dimensions of
    ``x``. Equivalent to ``x[m].mean(tuple(range(m.ndim)))``.

    Multiplying by the mask instead of indexing with it keeps tensor shapes
    static, which is much faster under torch-compile on batches. The price is
    a larger floating point error.
    """
    # Right-pad the mask with singleton axes so it broadcasts over the
    # trailing dims of x.
    weights = m
    for _ in range(x.ndim - m.ndim):
        weights = weights.unsqueeze(-1)

    kept = x * weights
    # Scale every element by the number of selected positions first, then sum.
    scaled = kept / m.sum()
    leading_dims = tuple(range(m.ndim))
    return scaled.sum(leading_dims)
|
|
|
|
|
def entropy_loss(
    logits,
    mask=None,
    sample_minimization_weight=1.0,
    batch_maximization_weight=1.0,
    eps=1e-5,
    temperature=0.1,
):
    """
    Entropy loss of unnormalized logits.

    The loss rewards confident per-sample assignments (low per-sample
    entropy) while rewarding an even use of all codes across the batch (high
    entropy of the averaged distribution):

        loss = sample_minimization_weight * sample_entropy
               - batch_maximization_weight * avg_entropy

    https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L279
    LANGUAGE MODEL BEATS DIFFUSION — TOKENIZER IS KEY TO VISUAL GENERATION (2024)

    Args:
        logits: affinities, with the candidate codes on the last dimension.
        mask: optional mask over the leading dimensions of ``logits``; when
            given, both averages are computed with ``masked_mean``.
        sample_minimization_weight: weight of the per-sample entropy term.
        batch_maximization_weight: weight of the averaged-distribution
            entropy term.
        eps: small constant added inside ``torch.log`` for stability.
        temperature: softmax temperature the logits are divided by. The
            default of 0.1 is the value that used to be hard-coded.

    Returns:
        Tuple ``(sample_entropy, avg_entropy, loss)`` of scalar tensors.

    Raises:
        ValueError: if ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    scaled_logits = logits / temperature
    probs = F.softmax(scaled_logits, -1)
    # softmax is shift invariant, so the `+ eps` here has no mathematical
    # effect; it is retained so results match the previous implementation.
    log_probs = F.log_softmax(scaled_logits + eps, -1)

    if mask is not None:
        avg_probs = masked_mean(probs, mask)
    else:
        # Average over every leading dimension, keeping only the code axis.
        avg_probs = probs.reshape(-1, probs.shape[-1]).mean(dim=0)

    avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + eps))

    sample_entropy = -torch.sum(probs * log_probs, -1)
    if mask is not None:
        sample_entropy = masked_mean(sample_entropy, mask).mean()
    else:
        sample_entropy = torch.mean(sample_entropy)

    loss = (sample_minimization_weight * sample_entropy) - (
        batch_maximization_weight * avg_entropy
    )

    return sample_entropy, avg_entropy, loss
|
|
|
|
|
class LFQ(Module):
    """Lookup-free quantizer: every feature is snapped to -1 or +1 by its sign.

    A code index is the integer whose binary digits are the signs of the
    ``codebook_dim`` features of one codebook. Feature ``i`` carries weight
    ``2 ** i``, i.e. the first feature is the least significant bit.

    Args:
        dim: feature size of the input. Defaults to
            ``codebook_dim * num_codebooks``.
        codebook_size: number of codes per codebook, must be a power of two.
            Defaults to ``2 ** dim``.
        num_codebooks: number of groups the feature dimension is split into.
        sample_minimization_weight: forwarded to ``entropy_loss``.
        batch_maximization_weight: forwarded to ``entropy_loss``.
        token_factorization: when true, ``forward`` reports each code as two
            sub-indices (``pre`` and ``post``) instead of a single index.
        factorized_bits: ``[n_pre, n_post]`` bit counts of the two
            sub-indices; defaults to ``[9, 9]``. Only used when
            ``token_factorization`` is true.

    NOTE(review): ``has_projections`` is recorded when ``dim`` differs from
    ``codebook_dim * num_codebooks``, but no projection layers are defined in
    this class; ``forward`` presumably expects the two to match. Confirm.
    """

    def __init__(
        self,
        *,
        dim=None,
        codebook_size=None,
        num_codebooks=1,
        sample_minimization_weight=1.0,
        batch_maximization_weight=1.0,
        token_factorization=False,
        factorized_bits=None,
    ):
        super().__init__()

        assert (
            dim is not None or codebook_size is not None
        ), "either dim or codebook_size must be specified for LFQ"
        assert (
            codebook_size is None or log2(codebook_size).is_integer()
        ), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})"

        # Resolve the codebook size once and use the resolved value from here
        # on, so that constructing with only `dim` works.
        if codebook_size is None:
            codebook_size = 2**dim
        self.codebook_size = codebook_size
        self.codebook_dim = int(log2(codebook_size))

        codebook_dims = self.codebook_dim * num_codebooks
        if dim is None:
            dim = codebook_dims

        self.has_projections = dim != codebook_dims
        self.dim = dim
        self.num_codebooks = num_codebooks

        # weights of the two entropy terms
        self.sample_minimization_weight = sample_minimization_weight
        self.batch_maximization_weight = batch_maximization_weight

        # Bit weights used by `forward` to turn sign patterns into indices.
        self.token_factorization = token_factorization
        if not self.token_factorization:
            self.register_buffer(
                "mask", 2 ** torch.arange(self.codebook_dim), persistent=False
            )
        else:
            # A fresh list per instance; a shared mutable default is avoided.
            self.factorized_bits = (
                [9, 9] if factorized_bits is None else list(factorized_bits)
            )
            self.register_buffer(
                "pre_mask",
                2 ** torch.arange(self.factorized_bits[0]),
                persistent=False,
            )
            self.register_buffer(
                "post_mask",
                2 ** torch.arange(self.factorized_bits[1]),
                persistent=False,
            )

        # Scalar placeholder returned for losses that are not computed.
        self.register_buffer("zero", torch.tensor(0.0), persistent=False)

        # Enumerate every code as a row of -1/+1 values:
        # shape (codebook_size, codebook_dim).
        all_codes = torch.arange(codebook_size)
        bits = self.indices_to_bits(all_codes)
        codebook = bits * 2.0 - 1.0

        self.register_buffer("codebook", codebook, persistent=False)

    @property
    def dtype(self):
        """dtype of the ``codebook`` buffer."""
        return self.codebook.dtype

    def indices_to_bits(self, x):
        """
        x: long tensor of indices

        returns a bool tensor with one extra trailing dimension of size
        ``codebook_dim``; entry ``i`` is bit ``i`` of the index (least
        significant bit first)
        """
        mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long)

        x = (x.unsqueeze(-1) & mask) != 0
        return x

    def get_codebook_entry(self, x, bhwc, order):
        """Turn indices back into a channels-first map of -1/+1 values.

        x: long tensor of indices; the rearrange below expects it to have
            shape ``(b, h * w)``
        bhwc: 4-tuple ``(b, h, w, c)``; only ``h``, ``w`` and ``c`` are used,
            and ``c`` must equal the number of bits being expanded
        order: with token factorization, ``"pre"`` expands
            ``factorized_bits[0]`` bits and any other value expands
            ``factorized_bits[1]`` bits; ignored otherwise

        returns a float tensor of shape ``(b, c, h, w)``
        """
        if not self.token_factorization:
            n_bits = self.codebook_dim
        elif order == "pre":
            n_bits = self.factorized_bits[0]
        else:
            n_bits = self.factorized_bits[1]

        mask = 2 ** torch.arange(n_bits, device=x.device, dtype=torch.long)

        x = (x.unsqueeze(-1) & mask) != 0
        x = x * 2.0 - 1.0

        _, h, w, c = bhwc
        x = rearrange(x, "b (h w) c -> b h w c", h=h, w=w, c=c)
        x = rearrange(x, "b h w c -> b c h w")
        return x

    def bits_to_indices(self, bits):
        """
        bits: bool tensor of bits, least significant bit first, where the
            last dimension is the bit dimension

        returns indices, which are long integers from 0 to
        self.codebook_size - 1
        """
        assert bits.shape[-1] == self.codebook_dim
        indices = 2 ** torch.arange(
            0,
            self.codebook_dim,
            1,
            dtype=torch.long,
            device=bits.device,
        )
        return (bits * indices).sum(-1)

    def decode(self, x):
        """
        x: ... NH
            where NH is number of codebook heads
            A longtensor of codebook indices, containing values from
            0 to self.codebook_size - 1

        returns a tensor of -1/+1 values of shape ``... (NH * codebook_dim)``
        in the dtype of the codebook buffer
        """
        x = self.indices_to_bits(x)

        # bool -> codebook dtype, then map {0, 1} onto {-1, +1}
        x = x.to(self.dtype)
        x = x * 2 - 1

        x = rearrange(x, "... NC Z-> ... (NC Z)")
        return x

    def forward(
        self,
        x,
        inv_temperature=100.0,
        return_loss_breakdown=False,
        mask=None,
        return_loss=True,
    ):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebooks

        Args:
            x: channels-first input of shape ``(b, c * d, ...)``.
            inv_temperature: accepted for interface compatibility; it is not
                read anywhere in this method.
            return_loss_breakdown: when true, additionally return a
                ``LossBreakdown``.
            mask: optional boolean mask used to index the element-wise
                commitment loss of shape ``(b, n, c, d)``. It is not passed
                to the entropy loss.
            return_loss: when false, the entropy loss is skipped.

        Returns:
            ``(quantized, entropy_aux_loss, indices)`` where ``quantized``
            has the layout of ``x`` and ``indices`` is a flattened 1-D
            tensor, or a ``(pre, post)`` pair of such tensors when token
            factorization is on. With ``return_loss_breakdown`` the result is
            ``((quantized, entropy_aux_loss, indices), LossBreakdown)``.
            The commitment loss is only available through the breakdown.
        """
        # channels-first -> channels-last, then flatten all middle dimensions
        x = rearrange(x, "b d ... -> b ... d")
        x, ps = pack_one(x, "b * d")

        # split the features into one group per codebook
        x = rearrange(x, "b n (c d) -> b n c d", c=self.num_codebooks)

        # Quantize by sign; an input of exactly 0 maps to -1. The constant is
        # created directly with the device and dtype of x.
        codebook_value = x.new_ones(1)
        quantized = torch.where(x > 0, codebook_value, -codebook_value)

        # Indices: weighted sum of the positive bits.
        if self.token_factorization:
            indices_pre = reduce(
                (quantized[..., : self.factorized_bits[0]] > 0).int()
                * self.pre_mask.int(),
                "b n c d -> b n c",
                "sum",
            )
            indices_post = reduce(
                (quantized[..., self.factorized_bits[0] :] > 0).int()
                * self.post_mask.int(),
                "b n c d -> b n c",
                "sum",
            )
        else:
            indices = reduce(
                (quantized > 0).int() * self.mask.int(), "b n c d -> b n c", "sum"
            )

        # Entropy auxiliary loss, training only.
        if self.training and return_loss:
            # affinity of every input vector to every code in the codebook
            logits = 2 * einsum("... i d, j d -> ... i j", x, self.codebook)

            per_sample_entropy, codebook_entropy, entropy_aux_loss = entropy_loss(
                logits=logits,
                sample_minimization_weight=self.sample_minimization_weight,
                batch_maximization_weight=self.batch_maximization_weight,
            )

            avg_probs = self.zero
        else:
            per_sample_entropy = codebook_entropy = self.zero
            entropy_aux_loss = self.zero
            avg_probs = self.zero

        # Commitment loss, training only (independent of `return_loss`).
        if self.training:
            commit_loss = F.mse_loss(x, quantized.detach(), reduction="none")

            if mask is not None:
                commit_loss = commit_loss[mask]

            commit_loss = commit_loss.mean()
        else:
            commit_loss = self.zero

        # Straight-through estimator: the forward value is `quantized`, the
        # gradient flows to `x` unchanged.
        quantized = x + (quantized - x).detach()

        # merge the codebook groups back and restore the original layout
        quantized = rearrange(quantized, "b n c d -> b n (c d)")
        quantized = unpack_one(quantized, ps, "b * d")
        quantized = rearrange(quantized, "b ... d -> b d ...")

        if self.token_factorization:
            indices_pre = unpack_one(indices_pre, ps, "b * c")
            indices_post = unpack_one(indices_post, ps, "b * c")
            indices_pre = indices_pre.flatten()
            indices_post = indices_post.flatten()
            indices = (indices_pre, indices_post)
        else:
            indices = unpack_one(indices, ps, "b * c")
            indices = indices.flatten()

        ret = (quantized, entropy_aux_loss, indices)

        if not return_loss_breakdown:
            return ret

        return ret, LossBreakdown(
            per_sample_entropy, codebook_entropy, commit_loss, avg_probs
        )
|
|