|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a gradient-trained codebook.

    Each ``embedding_dim``-sized vector of the input is replaced by the
    closest row of ``self.embedding.weight`` (squared Euclidean distance).

    Args:
        num_embeddings: Number of codebook entries.
        embedding_dim: Size of each codebook vector; must equal the channel
            dimension (dim 1) of the tensor passed to ``forward``.
        commitment_cost: Weight applied to the commitment (encoder-side)
            term of the loss.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        # Small symmetric uniform init, scaled by the codebook size.
        self.embedding.weight.data.uniform_(-1 / self.num_embeddings,
                                            1 / self.num_embeddings)
        self.commitment_cost = commitment_cost

    def forward(self, inputs):
        """Quantize ``inputs`` against the codebook.

        Args:
            inputs: Tensor of shape ``(batch, embedding_dim, length)``.

        Returns:
            A tuple ``(loss, quantized, perplexity, encodings)``:

            * ``loss`` -- codebook loss plus ``commitment_cost`` times the
              commitment loss (both MSE).
            * ``quantized`` -- quantized tensor, same shape as ``inputs``;
              gradients pass straight through to ``inputs``.
            * ``perplexity`` -- exponential of the entropy of the average
              code usage in this batch.
            * ``encodings`` -- one-hot assignments of shape
              ``(batch * length, num_embeddings)``.
        """
        # Move channels last so each row of the flattened view is one vector.
        inputs = inputs.permute(0, 2, 1).contiguous()
        input_shape = inputs.shape

        flat_input = inputs.view(-1, self.embedding_dim)

        # Squared distances via ||x||^2 + ||e||^2 - 2 * x . e
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(self.embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self.embedding.weight.t()))

        # One-hot encode the index of the nearest codebook entry.
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self.num_embeddings,
                                device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Select codebook rows and restore the (batch, length, dim) layout.
        quantized = torch.matmul(encodings, self.embedding.weight).view(input_shape)

        # Commitment term: gradient only reaches the inputs (encoder side).
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        # Codebook term: gradient only reaches the embedding weights.
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        # Straight-through estimator: forward value is the quantized vector,
        # backward gradient is copied to the inputs unchanged.
        quantized = inputs + (quantized - inputs).detach()

        avg_probs = torch.mean(encodings, dim=0)
        # 1e-10 guards log(0) for unused codes.
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # Back to channels-first for the caller.
        return loss, quantized.permute(0, 2, 1).contiguous(), perplexity, encodings
|
|
|
|
|
class VectorQuantizerEMA(nn.Module):
    """Vector quantizer whose codebook is updated by exponential moving averages.

    The codebook is not trained through the loss: in training mode each
    forward pass updates running cluster sizes and running sums of assigned
    vectors, and the codebook is set to their ratio.

    Args:
        num_embeddings: Number of codebook entries.
        embedding_dim: Size of each codebook vector; must equal the channel
            dimension (dim 1) of the tensor passed to ``forward``.
        commitment_cost: Weight applied to the commitment loss.
        decay: EMA decay factor applied to the running statistics.
        epsilon: Laplace-smoothing constant that keeps cluster sizes of
            unused codes away from zero.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.embedding.weight.data.normal_()
        self.commitment_cost = commitment_cost

        # Running (smoothed) count of vectors assigned to each code.
        self.register_buffer('ema_cluster_size', torch.zeros(num_embeddings))
        # Running sum of the vectors assigned to each code.
        self.ema_w = nn.Parameter(torch.empty(num_embeddings, self.embedding_dim))
        self.ema_w.data.normal_()

        self.decay = decay
        self.epsilon = epsilon

    def forward(self, inputs):
        """Quantize ``inputs`` and, in training mode, update the codebook.

        Args:
            inputs: Tensor of shape ``(batch, embedding_dim, length)``.

        Returns:
            A tuple ``(loss, quantized, perplexity, encoding_indices)``:

            * ``loss`` -- ``commitment_cost`` times the MSE between the
              inputs and their (detached) quantized values.
            * ``quantized`` -- quantized tensor, same shape as ``inputs``,
              built from the codebook as it was *before* this call's EMA
              update; gradients pass straight through to ``inputs``.
            * ``perplexity`` -- exponential of the entropy of the average
              code usage in this batch.
            * ``encoding_indices`` -- nearest-code indices of shape
              ``(batch * length, 1)``.
        """
        # Move channels last so each row of the flattened view is one vector.
        inputs = inputs.permute(0, 2, 1).contiguous()
        input_shape = inputs.shape

        flat_input = inputs.view(-1, self.embedding_dim)

        # Squared distances via ||x||^2 + ||e||^2 - 2 * x . e
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(self.embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self.embedding.weight.t()))

        # One-hot encode the index of the nearest codebook entry.
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self.num_embeddings,
                                device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Select codebook rows and restore the (batch, length, dim) layout.
        quantized = torch.matmul(encodings, self.embedding.weight).view(input_shape)

        if self.training:
            # The EMA statistics are not part of the autograd graph. Updating
            # in place keeps the Parameter/buffer objects stable, so any
            # external reference (optimizer, hooks) stays valid.
            with torch.no_grad():
                cluster_size = (self.ema_cluster_size * self.decay
                                + (1 - self.decay) * torch.sum(encodings, 0))

                # Laplace smoothing: add epsilon to every count, then rescale
                # so that the smoothed counts still sum to n.
                n = torch.sum(cluster_size)
                cluster_size = ((cluster_size + self.epsilon)
                                / (n + self.num_embeddings * self.epsilon) * n)
                self.ema_cluster_size.copy_(cluster_size)

                # Per-code sum of the vectors assigned in this batch.
                dw = torch.matmul(encodings.t(), flat_input)
                self.ema_w.copy_(self.ema_w * self.decay + (1 - self.decay) * dw)

                # New codebook = running sum / running count.
                self.embedding.weight.copy_(
                    self.ema_w / self.ema_cluster_size.unsqueeze(1))

        # Only the commitment term is needed; the codebook is EMA-updated.
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self.commitment_cost * e_latent_loss

        # Straight-through estimator: forward value is the quantized vector,
        # backward gradient is copied to the inputs unchanged.
        quantized = inputs + (quantized - inputs).detach()

        avg_probs = torch.mean(encodings, dim=0)
        # 1e-10 guards log(0) for unused codes.
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # Back to channels-first for the caller.
        return loss, quantized.permute(0, 2, 1).contiguous(), perplexity, encoding_indices
|
|
|
|
|
class Residual(nn.Module):
    """Two-convolution residual block for 1-D feature maps.

    The branch is ReLU -> 3-wide circular conv -> ReLU -> 1-wide conv, both
    convolutions without bias, and its output is added to the input.

    Args:
        in_channels: Channels of the incoming tensor.
        num_hiddnes: Channels produced by the branch (spelling kept for
            keyword-argument compatibility). The addition in ``forward``
            needs this to be compatible with ``in_channels``.
        num_residual_hiddens: Channels between the two convolutions.
    """

    def __init__(self, in_channels, num_hiddnes, num_residual_hiddens):
        super().__init__()

        squeeze_conv = nn.Conv1d(in_channels, num_residual_hiddens,
                                 kernel_size=3, stride=1, padding=1,
                                 bias=False, padding_mode='circular')
        expand_conv = nn.Conv1d(num_residual_hiddens, num_hiddnes,
                                kernel_size=1, stride=1, bias=False)

        self.block = nn.Sequential(
            nn.ReLU(inplace=True),
            squeeze_conv,
            nn.ReLU(inplace=True),
            expand_conv,
        )

    def forward(self, x):
        """Return the input plus the branch output.

        NOTE: the first ReLU runs in place, so ``x`` itself is rectified
        before the addition; the result is ``relu(x) + branch(relu(x))`` and
        the caller's tensor is modified.
        """
        branch_out = self.block(x)
        return x + branch_out
|
|
|
|
|
class ResidualStack(nn.Module):
    """Sequence of ``Residual`` blocks followed by a final ReLU.

    Args:
        in_channels: Input channels handed to every ``Residual`` block.
        num_hiddens: Output channels handed to every ``Residual`` block.
        num_residual_layers: How many blocks to stack.
        num_residual_hiddens: Inner channel count of every block.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()
        self.num_residual_layers = num_residual_layers

        stacked = []
        for _ in range(num_residual_layers):
            stacked.append(Residual(in_channels, num_hiddens, num_residual_hiddens))
        self.layers = nn.ModuleList(stacked)

    def forward(self, x):
        """Run ``x`` through each block in order, then apply ReLU."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return F.relu(out)
|
|
|
|
|
class Encoder(nn.Module):
    """Convolutional encoder for 1-D signals.

    Four stride-2 convolutions (kernel 4, circular padding 1), each followed
    by ReLU, then a stride-1 convolution (kernel 3) and a residual stack.
    With these settings each stride-2 layer halves an even input length.

    Args:
        in_channels: Channels of the input signal.
        num_hiddens: Channels of the encoder output; the first convolution
            produces ``num_hiddens // 2`` channels.
        num_residual_layers: Number of blocks in the residual stack.
        num_residual_hiddens: Inner channel count of each residual block.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()

        half_hiddens = num_hiddens // 2
        # Settings shared by the four downsampling convolutions.
        down = dict(kernel_size=4, stride=2, padding=1, padding_mode='circular')

        self.conv_1 = nn.Conv1d(in_channels, half_hiddens, **down)
        self.conv_2 = nn.Conv1d(half_hiddens, num_hiddens, **down)
        self.conv_3 = nn.Conv1d(num_hiddens, num_hiddens, **down)
        self.conv_4 = nn.Conv1d(num_hiddens, num_hiddens, **down)

        self.conv_final = nn.Conv1d(num_hiddens, num_hiddens,
                                    kernel_size=3, stride=1, padding=1,
                                    padding_mode='circular')

        self.residual_stack = ResidualStack(
            in_channels=num_hiddens,
            num_hiddens=num_hiddens,
            num_residual_layers=num_residual_layers,
            num_residual_hiddens=num_residual_hiddens,
        )

    def forward(self, inputs):
        """Encode ``inputs`` of shape ``(batch, in_channels, length)``."""
        x = inputs
        for conv in (self.conv_1, self.conv_2, self.conv_3, self.conv_4):
            x = F.relu(conv(x))

        # No activation between conv_final and the residual stack.
        return self.residual_stack(self.conv_final(x))
|
|
|
|
|
class Decoder(nn.Module):
    """Transposed-convolution decoder producing a single-channel 1-D signal.

    A stride-1 convolution (kernel 3, zero padding) and a residual stack are
    followed by four stride-2 transposed convolutions (kernel 4, padding 1).
    ReLU is applied after the first three transposed convolutions; the last
    one is returned without activation.

    Args:
        in_channels: Channels of the incoming latent tensor.
        num_hiddens: Working channel count; the third transposed convolution
            reduces it to ``num_hiddens // 2``.
        num_residual_layers: Number of blocks in the residual stack.
        num_residual_hiddens: Inner channel count of each residual block.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()

        half_hiddens = num_hiddens // 2
        # Settings shared by the four upsampling layers.
        up = dict(kernel_size=4, stride=2, padding=1)

        self.conv_init = nn.Conv1d(in_channels, num_hiddens,
                                   kernel_size=3, stride=1, padding=1)

        self.residual_stack = ResidualStack(
            in_channels=num_hiddens,
            num_hiddens=num_hiddens,
            num_residual_layers=num_residual_layers,
            num_residual_hiddens=num_residual_hiddens,
        )

        self.conv_trans_0 = nn.ConvTranspose1d(num_hiddens, num_hiddens, **up)
        self.conv_trans_1 = nn.ConvTranspose1d(num_hiddens, num_hiddens, **up)
        self.conv_trans_2 = nn.ConvTranspose1d(num_hiddens, half_hiddens, **up)
        # Output layer: one channel.
        self.conv_trans_3 = nn.ConvTranspose1d(half_hiddens, 1, **up)

    def forward(self, inputs):
        """Decode ``inputs`` of shape ``(batch, in_channels, length)``."""
        x = self.residual_stack(self.conv_init(inputs))

        for upsample in (self.conv_trans_0, self.conv_trans_1, self.conv_trans_2):
            x = F.relu(upsample(x))

        return self.conv_trans_3(x)
|
|
|
|
|
class VQVAE(nn.Module):
    """VQ-VAE for single-channel 1-D signals.

    Pipeline: ``Encoder`` -> 1x1 convolution to ``embedding_dim`` channels ->
    vector quantizer -> ``Decoder``.

    Args:
        num_hiddens: Channel width of the encoder and decoder.
        num_residual_layers: Residual blocks per residual stack.
        num_residual_hiddens: Inner channel count of each residual block.
        num_embeddings: Number of codebook entries.
        embedding_dim: Size of each codebook vector.
        commitment_cost: Weight of the commitment loss in the quantizer.
        decay: If greater than 0, use ``VectorQuantizerEMA`` with this decay;
            otherwise use ``VectorQuantizer``.
    """

    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens, num_embeddings,
                 embedding_dim, commitment_cost, decay=0):
        super().__init__()

        # The encoder input is fixed to one channel.
        self.encoder = Encoder(1, num_hiddens, num_residual_layers, num_residual_hiddens)

        # Project encoder features to the codebook dimensionality.
        self.pre_vq_conv = nn.Conv1d(num_hiddens, embedding_dim, kernel_size=1, stride=1)

        use_ema = decay > 0.0
        if not use_ema:
            self.vq = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
        else:
            self.vq = VectorQuantizerEMA(num_embeddings, embedding_dim, commitment_cost, decay)

        self.decoder = Decoder(embedding_dim, num_hiddens,
                               num_residual_layers, num_residual_hiddens)

    def _pre_quantize(self, x):
        """Encode ``x`` and project it to ``embedding_dim`` channels."""
        return self.pre_vq_conv(self.encoder(x))

    def encode(self, x):
        """Return ``(quantized, encoding_indices)`` for input ``x``.

        The second element is whatever the configured quantizer returns in
        its fourth position. This runs the quantizer's full ``forward``.
        """
        vq_out = self.vq(self._pre_quantize(x))
        return vq_out[1], vq_out[3]

    def decode(self, x):
        """Run the decoder on (quantized) latents ``x``."""
        return self.decoder(x)

    def forward(self, x):
        """Return ``(vq_loss, reconstruction, perplexity)`` for input ``x``."""
        vq_loss, latents, perplexity, _ = self.vq(self._pre_quantize(x))
        reconstruction = self.decoder(latents)
        return vq_loss, reconstruction, perplexity
|
|
|
|
|
|
|
|
|
|