|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops import rearrange |
|
|
|
from .. import models |
|
from ..models import register |
|
|
|
|
|
@register("bottleneck")
class Bottleneck(nn.Module):
    """Project tokens into a bottleneck space, regularize them, and project back out.

    ``forward`` applies ``in_linear`` -> regularizer -> ``out_linear``. The
    regularizer is built from a config via ``models.make``. Its output dict
    must contain ``'regularized_z'`` and ``'bottleneck_rep'``.

    Args:
        bottleneck_dim: Width of the bottleneck. If ``<= 0``, ``input_dim`` is
            used instead, which requires ``input_dim == output_dim``.
        input_dim: Feature width of the incoming tokens.
        output_dim: Feature width produced by ``project_out``.
        token_nums: Forwarded to the regularizer config as ``args['token_nums']``.
        regularizer: Config for ``models.make``. Its ``'args'`` entry is
            extended with ``dim`` and ``token_nums`` on a copy; the caller's
            dict is left unchanged.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``regularizer`` is ``None``.
    """

    def __init__(
        self,
        bottleneck_dim: int,
        input_dim: int,
        output_dim: int,
        token_nums: int,
        regularizer=None,
        **kwargs
    ):
        super().__init__()
        self.token_nums = token_nums
        self.input_dim = input_dim
        self.output_dim = output_dim

        if bottleneck_dim > 0:
            self.bottleneck_dim = bottleneck_dim
        else:
            assert self.input_dim == self.output_dim, "input_dim and output_dim must be the same when bottleneck_dim is not specified"
            self.bottleneck_dim = self.input_dim

        self.project_dim = self.bottleneck_dim

        # NOTE(review): after the fallback above, bottleneck_dim == input_dim, so
        # this condition is False only when input_dim <= 0. The identity branch is
        # therefore effectively unreachable for normal configs. The behavior is kept
        # as-is so parameter names in existing checkpoints do not change.
        if self.bottleneck_dim > 0:
            self.in_linear = nn.Linear(self.input_dim, self.project_dim)
            self.out_linear = nn.Linear(self.bottleneck_dim, self.output_dim)
        else:
            # nn.Identity (unlike a lambda) is a real submodule and keeps the model picklable.
            self.in_linear = self.out_linear = nn.Identity()

        if regularizer is None:
            raise ValueError("Bottleneck requires a `regularizer` config")

        # Patch a copy of the config so a caller's (possibly shared) dict is not mutated.
        regularizer_args = dict(regularizer.get('args') or {})
        regularizer_args['dim'] = self.bottleneck_dim
        regularizer_args['token_nums'] = self.token_nums
        regularizer_spec = {**regularizer, 'args': regularizer_args}
        self.regularizer = models.make(regularizer_spec)

    def project_in(self, x):
        """Apply the input projection to a rank-3 token tensor.

        Args:
            x: Tensor of shape ``(batch, n_tokens, e_dim)``.

        Returns:
            ``in_linear(x)``.
        """
        assert len(x.shape) == 3, "Input shape must be (batch, n_tokens, e_dim)"
        return self.in_linear(x)

    def project_out(self, z_cat):
        """Apply the output projection."""
        return self.out_linear(z_cat)

    def decode(self, bottleneck_rep):
        """Map a bottleneck representation back to output space.

        Uses the regularizer's ``decode`` and then ``project_out``.
        """
        regularized_z = self.regularizer.decode(bottleneck_rep)
        return self.project_out(regularized_z)

    def forward(self, x):
        """Run project_in -> regularizer -> project_out.

        Returns:
            dict with ``'output'`` (reconstruction), ``'bottleneck_rep'`` (popped
            from the regularizer output), ``'projected_z'`` (the projected input
            before regularization), and every remaining key the regularizer
            returned. Regularizer keys are merged last, so they take precedence
            on name clashes.
        """
        z = self.project_in(x)
        regularized_output = self.regularizer(z)
        x_hat = self.project_out(regularized_output['regularized_z'])
        # Removed from the regularizer dict so it is not duplicated by the ** merge below.
        bottleneck_rep = regularized_output.pop('bottleneck_rep')
        return {
            'output': x_hat,
            'bottleneck_rep': bottleneck_rep,
            'projected_z': z,
            **regularized_output,
        }
|
|
|
|
|
@register("simvq")
class SimVectorQuantizer(nn.Module):
    """Vector quantizer whose codebook is a learned linear projection of an embedding table.

    The effective codebook is ``embedding_proj(embedding.weight)``, optionally
    L2-normalized. Each input vector is assigned a code in one of two ways:

    - nearest squared-Euclidean distance (``stochastic=False``);
    - sampling from a softmax over cosine similarities (``stochastic=True``).

    A straight-through estimator passes gradients from the quantized output
    back to the input.

    Args:
        dim: Width of each codebook vector and of the input features.
        codebook_size: Number of codebook entries.
        l2_normalized: If True, L2-normalize both inputs and codebook rows.
        same_index_shape: If True, indices are returned with shape
            ``(batch, n_tokens)``; otherwise they stay flattened.
        stochastic: If True, sample codes instead of taking the nearest one.
            ``forward`` asserts that ``l2_normalized`` is also set.
        stochastic_temperature: Softmax temperature for stochastic sampling.
            A value ``<= 0`` makes the inverse temperature a learnable
            ``nn.Parameter`` initialized to 10.0.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        dim,
        codebook_size,
        l2_normalized=False,
        same_index_shape=True,
        stochastic=False,
        stochastic_temperature=1.0,
        **kwargs,
    ):
        super().__init__()
        self.codebook_size = codebook_size
        self.dim = dim
        assert isinstance(l2_normalized, bool)
        self.l2_normalized = l2_normalized
        self.stochastic = stochastic
        # Only consulted in stochastic mode: in eval mode, take the argmax instead of sampling.
        self.eval_deterministic = False
        self.default_stochastic_temperature = stochastic_temperature

        if self.stochastic:
            if stochastic_temperature > 0:
                self.stochastic_temperature_inv = 1 / stochastic_temperature
            else:
                # Non-positive temperature => learnable inverse temperature.
                # NOTE(review): in forward() the softmax probabilities only feed
                # argmax/multinomial, so this parameter receives no gradient from
                # this module. Confirm it is trained or set elsewhere.
                self.stochastic_temperature_inv = nn.Parameter(torch.tensor(10.0))

        # Raw table; the codebook actually used is embedding_proj(embedding.weight).
        self.embedding = nn.Embedding(self.codebook_size, self.dim)
        self.embedding_proj = nn.Linear(self.dim, self.dim)

        self.same_index_shape = same_index_shape

    def set_eval_deterministic(self, deterministic=True):
        """Choose argmax (True) or sampling (False) for stochastic mode in eval."""
        self.eval_deterministic = deterministic

    def set_stochastic_temperature(self, temperature):
        """Set the sampling temperature; the module stores its inverse.

        If the inverse temperature is a learnable ``nn.Parameter``, it is
        overwritten in place so it stays registered. Plain assignment of a
        non-Parameter to a registered parameter name raises ``TypeError`` in
        ``nn.Module.__setattr__``.
        """
        inv_temperature = 1 / temperature
        current = getattr(self, 'stochastic_temperature_inv', None)
        if isinstance(current, nn.Parameter):
            with torch.no_grad():
                current.fill_(inv_temperature)
        else:
            self.stochastic_temperature_inv = inv_temperature

    @torch.autocast(device_type='cuda', enabled=False)
    def get_emb(self):
        """Return the effective codebook of shape ``(codebook_size, dim)``.

        The codebook is computed with CUDA autocast disabled and its rows are
        L2-normalized when ``l2_normalized`` is set.
        """
        emb = self.embedding_proj(self.embedding.weight)
        if self.l2_normalized:
            emb = F.normalize(emb, p=2, dim=-1)
        return emb

    @torch.autocast(device_type='cuda', enabled=False)
    def forward(self, z):
        """Quantize ``z`` of shape ``(batch, n_tokens, dim)``.

        Runs with CUDA autocast disabled. ``z`` is first cast to the
        codebook's dtype and device.

        Returns:
            dict with
              - ``'unregularized_z'``: the cast (and, if configured,
                L2-normalized) input;
              - ``'emb'``: the effective codebook;
              - ``'regularized_z'``: the quantized vectors, same shape as
                ``z``, with straight-through gradients;
              - ``'bottleneck_rep'``: code indices, shaped
                ``(batch, n_tokens)`` if ``same_index_shape`` is set,
                otherwise flattened.
        """
        emb = self.get_emb()
        z = z.to(emb)

        assert len(z.shape) == 3, "Input shape must be (batch, n_tokens, e_dim)"
        if self.l2_normalized:
            z = F.normalize(z, p=2, dim=-1)

        z_flattened = rearrange(z, 'b n d -> (b n) d')

        if self.stochastic:
            assert self.l2_normalized, "Stochastic sampling requires l2 normalization"
            # Both sides are unit-norm here, so the dot product is cosine similarity.
            cos_sim = torch.einsum("bd,nd->bn", z_flattened, emb)
            probs = F.softmax(cos_sim * self.stochastic_temperature_inv, dim=-1)
            if self.eval_deterministic and not self.training:
                q_indices = torch.argmax(probs, dim=-1)
            else:
                q_indices = torch.multinomial(probs, 1).squeeze(-1)
        else:
            # Squared Euclidean distance ||z||^2 + ||e||^2 - 2 z.e for every (input, code) pair.
            d = (
                torch.sum(z_flattened**2, dim=1, keepdim=True)
                + torch.sum(emb**2, dim=1)
                - 2
                * torch.einsum(
                    "bd,dn->bn", z_flattened, rearrange(emb, "n d -> d n")
                )
            )
            q_indices = torch.argmin(d, dim=1)

        quantized = F.embedding(
            q_indices,
            emb,
            self.embedding.padding_idx,
            self.embedding.max_norm,
            self.embedding.norm_type,
            self.embedding.scale_grad_by_freq,
            self.embedding.sparse,
        ).view(z.shape)

        # Straight-through estimator: the forward value is the code, the gradient flows to z.
        quantized = z + (quantized - z).detach()

        if self.same_index_shape:
            q_indices = q_indices.reshape(quantized.shape[0], quantized.shape[1])

        return {
            'unregularized_z': z,
            'emb': emb,
            'regularized_z': quantized,
            'bottleneck_rep': q_indices,
        }

    def get_codebook_entry(self, indices, shape=None):
        """Look up codebook vectors for integer code ``indices`` of any shape.

        Args:
            indices: Integer tensor of code indices.
            shape: Optional output shape. Defaults to ``(*indices.shape, dim)``.

        Returns:
            Tensor of codebook rows reshaped to ``shape``.
        """
        indices_shape = indices.shape
        indices_flatten = rearrange(indices, '... -> (...)')

        # get_emb() already L2-normalizes rows when l2_normalized is set, so the
        # looked-up vectors are used as-is and match what forward() produces.
        emb = self.get_emb()
        z_q = F.embedding(indices_flatten, emb)

        if shape is not None:
            return z_q.reshape(shape)
        return z_q.reshape([*indices_shape, self.dim])

    def decode(self, indices):
        """Map code indices back to codebook vectors (see ``get_codebook_entry``)."""
        return self.get_codebook_entry(indices)