gist_demo / gist1 /vqvae.py
Mikhael Johanes
push files
d491737
# reference https://github.com/zalandoresearch/pytorch-vq-vae
import torch
import torch.nn as nn
import torch.nn.functional as F
class VectorQuantizer(nn.Module):
def __init__(self, num_embeddings, embedding_dim, commitment_cost):
super().__init__()
self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
self.embedding.weight.data.uniform_(-1/self.num_embeddings, 1/self.num_embeddings)
self.commitment_cost = commitment_cost
def forward(self, inputs):
# convert input from BCW -> BWC
inputs = inputs.permut(0, 2, 1).contiguous()
input_shape = inputs.shape
# flatten input
flat_input = inputs.view(-1, self.embedding_dim)
# calculate distances
distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
+ torch.sum(self.embedding.weight**2, dim=1)
- 2 * torch.matmul(flat_input, self.embedding.weight.t()))
# encoding
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
encodings = torch.zeros(encoding_indices.shape[0], self.num_embeddings, device=inputs.device)
encodings.scatter_(1, encoding_indices, 1)
# quantize and unflatten
quantized = torch.matmul(encodings, self.embedding.weight).view(input_shape)
# loss
e_latent_loss = F.mse_loss(quantized.detach(), inputs)
q_latent_loss = F.mse_loss(quantized, input.detach())
loss = q_latent_loss + self.commitment_cost * e_latent_loss
quantized = inputs + (quantized - inputs).detach()
avg_probs = torch.mean(encodings, dim=0)
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
# convert quantized from BWC -> BCW
return loss, quantized.permute(0, 2, 1).contiguous(), perplexity, encodings
class VectorQuantizerEMA(nn.Module):
def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
super().__init__()
self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
self.embedding.weight.data.normal_()
self.commitment_cost = commitment_cost
self.register_buffer('ema_cluster_size', torch.zeros(num_embeddings))
self.ema_w = nn.Parameter(torch.Tensor(num_embeddings, self.embedding_dim))
self.ema_w.data.normal_()
self.decay = decay
self.epsilon = epsilon
def forward(self, inputs):
#convert inputs from BCW -> BWC
inputs = inputs.permute(0, 2, 1).contiguous()
input_shape = inputs.shape
# flatten input
flat_input = inputs.view(-1, self.embedding_dim)
# calculate distances
distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
+ torch.sum(self.embedding.weight**2, dim=1)
- 2 * torch.matmul(flat_input, self.embedding.weight.t()))
# encoding
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
encodings = torch.zeros(encoding_indices.shape[0], self.num_embeddings, device=inputs.device)
encodings.scatter_(1, encoding_indices, 1)
# quantize and unflatten
quantized = torch.matmul(encodings, self.embedding.weight).view(input_shape)
# use EMA to update the embedding vectors
if self.training:
self.ema_cluster_size = self.ema_cluster_size * self.decay + (1 - self.decay) * torch.sum(encodings, 0)
# laplace smoothing of the cluster size
n = torch.sum(self.ema_cluster_size)
self.ema_cluster_size = self.ema_cluster_size + self.epsilon / (n + self.num_embeddings * self.epsilon * n)
dw = torch.matmul(encodings.t(), flat_input)
self.ema_w = nn.Parameter(self.ema_w * self.decay + (1 - self.decay) * dw)
self.embedding.weight = nn.Parameter(self.ema_w / self.ema_cluster_size.unsqueeze(1))
# loss
e_latent_loss = F.mse_loss(quantized.detach(), inputs)
loss = self.commitment_cost * e_latent_loss
# straight trough estimator
quantized = inputs + (quantized - inputs).detach()
avg_probs = torch.mean(encodings, dim=0)
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
# convert quantized from BWC -> BCW
return loss, quantized.permute(0, 2, 1).contiguous(), perplexity, encoding_indices
class Residual(nn.Module):
def __init__(self, in_channels, num_hiddnes, num_residual_hiddens):
super().__init__()
self.block = nn.Sequential( nn.ReLU(inplace=True),
nn.Conv1d( in_channels=in_channels,
out_channels=num_residual_hiddens,
kernel_size=3, stride=1, padding=1, bias=False, padding_mode='circular'),
nn.ReLU(inplace=True),
nn.Conv1d(in_channels=num_residual_hiddens,
out_channels=num_hiddnes,
kernel_size=1, stride=1, bias=False)
)
def forward(self, x):
return x + self.block(x)
class ResidualStack(nn.Module):
def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
super().__init__()
self.num_residual_layers = num_residual_layers
self.layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
for _ in range(self.num_residual_layers)])
def forward(self, x):
for i in range(self.num_residual_layers):
x = self.layers[i](x)
return F.relu(x)
class Encoder(nn.Module):
def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
super().__init__()
# 256 -> 128
self.conv_1 = nn.Conv1d(in_channels=in_channels,
out_channels=num_hiddens//2,
kernel_size=4,
stride=2, padding=1, padding_mode='circular')
# 128 -> 64
self.conv_2 = nn.Conv1d(in_channels=num_hiddens//2,
out_channels=num_hiddens,
kernel_size=4,
stride=2, padding=1, padding_mode='circular')
# 64 -> 32
self.conv_3 = nn.Conv1d(in_channels=num_hiddens,
out_channels=num_hiddens,
kernel_size=4,
stride=2, padding=1, padding_mode='circular')
# 32 -> 16
self.conv_4 = nn.Conv1d(in_channels=num_hiddens,
out_channels=num_hiddens,
kernel_size=4,
stride=2, padding=1, padding_mode='circular')
self.conv_final = nn.Conv1d(in_channels=num_hiddens,
out_channels=num_hiddens,
kernel_size=3,
stride=1, padding=1, padding_mode='circular')
self.residual_stack = ResidualStack(in_channels=num_hiddens,
num_hiddens=num_hiddens,
num_residual_hiddens=num_residual_hiddens,
num_residual_layers=num_residual_layers)
def forward(self, inputs):
x = self.conv_1(inputs)
x = F.relu(x)
x = self.conv_2(x)
x = F.relu(x)
x = self.conv_3(x)
x = F.relu(x)
x = self.conv_4(x)
x = F.relu(x)
x = self.conv_final(x)
x = self.residual_stack(x)
return x
class Decoder(nn.Module):
def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
super().__init__()
self.conv_init = nn.Conv1d( in_channels=in_channels,
out_channels=num_hiddens,
kernel_size=3,
stride=1, padding=1)
self.residual_stack = ResidualStack(in_channels=num_hiddens,
num_hiddens=num_hiddens,
num_residual_layers=num_residual_layers,
num_residual_hiddens=num_residual_hiddens)
# 16 -> 32
self.conv_trans_0 = nn.ConvTranspose1d( in_channels=num_hiddens,
out_channels=num_hiddens,
kernel_size=4,
stride=2, padding=1)
# 32 -> 64
self.conv_trans_1 = nn.ConvTranspose1d( in_channels=num_hiddens,
out_channels=num_hiddens,
kernel_size=4,
stride=2, padding=1)
# 64 -> 128
self.conv_trans_2 = nn.ConvTranspose1d( in_channels=num_hiddens,
out_channels=num_hiddens//2,
kernel_size=4,
stride=2, padding=1)
# 128 -> 256
self.conv_trans_3 = nn.ConvTranspose1d( in_channels=num_hiddens//2,
out_channels=1,
kernel_size=4,
stride=2, padding=1)
def forward(self, inputs):
x = self.conv_init(inputs)
x = self.residual_stack(x)
x = self.conv_trans_0(x)
x = F.relu(x)
x = self.conv_trans_1(x)
x = F.relu(x)
x = self.conv_trans_2(x)
x = F.relu(x)
return self.conv_trans_3(x)
class VQVAE(nn.Module):
def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens, num_embeddings,
embedding_dim, commitment_cost, decay=0):
super().__init__()
self.encoder = Encoder( 1, num_hiddens,
num_residual_layers,
num_residual_hiddens)
self.pre_vq_conv = nn.Conv1d( in_channels=num_hiddens,
out_channels=embedding_dim,
kernel_size=1,
stride=1)
if decay > 0.0:
self.vq = VectorQuantizerEMA(num_embeddings, embedding_dim, commitment_cost, decay)
else:
self.vq = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
self.decoder = Decoder( embedding_dim,
num_hiddens,
num_residual_layers,
num_residual_hiddens)
def encode(self, x):
z = self.encoder(x)
z = self.pre_vq_conv(z)
_, quantized, _, encoding_indices = self.vq(z)
return quantized, encoding_indices
def decode(self, x):
return self.decoder(x)
def forward(self, x):
z = self.encoder(x)
z = self.pre_vq_conv(z)
loss, quantized, perplexity, _ = self.vq(z)
x_recon = self.decoder(quantized)
return loss, x_recon, perplexity