import torch
from torch import nn
from torch import einsum
from torch.nn import functional as F


class VectorQuantize(nn.Module):
    def __init__(self, hidden_dim, embedding_dim, n_embed, commitment_cost=1):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.n_embed = n_embed
        self.commitment_cost = commitment_cost

        self.proj = nn.Conv2d(hidden_dim, embedding_dim, 1)
        self.embed = nn.Embedding(n_embed, embedding_dim)
        self.embed.weight.data.uniform_(-1. / n_embed, 1. / n_embed)

    def forward(self, z):
        B, C, H, W = z.shape

        # Project encoder features to the codebook dimension and flatten spatially.
        z_e = self.proj(z)
        z_e = z_e.permute(0, 2, 3, 1)  # (B, H, W, C)
        flatten = z_e.reshape(-1, self.embedding_dim)

        # Squared Euclidean distance between every feature vector and every code.
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed.weight.t()
            + self.embed.weight.pow(2).sum(1, keepdim=True).t()
        )
        _, embed_ind = (-dist).max(1)
        embed_ind = embed_ind.view(B, H, W)
        z_q = self.embed_code(embed_ind)

        # Commitment term (pulls z_e toward its code) + codebook term (pulls the code toward z_e).
        diff = self.commitment_cost * (z_q.detach() - z_e).pow(2).mean() \
            + (z_q - z_e.detach()).pow(2).mean()

        # Straight-through estimator: gradients flow from z_q back into z_e.
        z_q = z_e + (z_q - z_e).detach()
        return z_q, diff, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.weight)


class VectorQuantizeEMA(nn.Module):
    def __init__(self, hidden_dim, embedding_dim, n_embed, commitment_cost=1,
                 decay=0.99, eps=1e-5, pre_proj=True, training_loc=True):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.n_embed = n_embed
        self.commitment_cost = commitment_cost
        self.training_loc = training_loc

        self.pre_proj = pre_proj
        if self.pre_proj:
            self.proj = nn.Conv2d(hidden_dim, embedding_dim, 1)
        self.embed = nn.Embedding(n_embed, embedding_dim)
        self.embed.weight.data.uniform_(-1. / n_embed, 1. / n_embed)

        self.register_buffer("cluster_size", torch.zeros(n_embed))
        self.register_buffer("embed_avg", self.embed.weight.data.clone())
        self.decay = decay
        self.eps = eps

    def forward(self, z):
        B, C, H, W = z.shape

        if self.pre_proj:
            z_e = self.proj(z)
        else:
            z_e = z
        z_e = z_e.permute(0, 2, 3, 1)  # (B, H, W, C)
        flatten = z_e.reshape(-1, self.embedding_dim)

        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed.weight.t()
            + self.embed.weight.pow(2).sum(1, keepdim=True).t()
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
        embed_ind = embed_ind.view(B, H, W)
        z_q = self.embed_code(embed_ind)

        if self.training and self.training_loc:
            # Standard exponential-moving-average codebook update: track per-code
            # assignment counts and feature sums, then renormalize the embedding
            # (Laplace-smoothed with eps so rarely used codes stay finite).
            embed_onehot_sum = embed_onehot.sum(0)
            embed_sum = embed_onehot.t() @ flatten  # (n_embed, embedding_dim)

            self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)

            n = self.cluster_size.sum()
            cluster_size = (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
            self.embed.weight.data.copy_(self.embed_avg / cluster_size.unsqueeze(1))

        # Only the commitment term is needed; the codebook itself is updated by EMA, not by gradient.
        diff = self.commitment_cost * (z_q.detach() - z_e).pow(2).mean()

        z_q = z_e + (z_q - z_e).detach()
        return z_q, diff, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.weight)


class GumbelQuantize(nn.Module):
    def __init__(self, hidden_dim, embedding_dim, n_embed, commitment_cost=1,
                 straight_through=True, kl_weight=5e-4, temp_init=1., eps=1e-5):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.n_embed = n_embed
        self.commitment_cost = commitment_cost
        self.kl_weight = kl_weight
        self.temperature = temp_init
        self.eps = eps

        self.proj = nn.Conv2d(hidden_dim, n_embed, 1)
        self.embed = nn.Embedding(n_embed, embedding_dim)
        self.embed.weight.data.uniform_(-1. / n_embed, 1. / n_embed)
        self.straight_through = straight_through

    def forward(self, z, temp=None):
        # During training, hard (one-hot) sampling is controlled by `straight_through`;
        # at eval time codes are always hard.
        hard = self.straight_through if self.training else True
        temp = self.temperature if temp is None else temp

        B, C, H, W = z.shape

        # Project features to per-code logits and draw a (relaxed) one-hot code map.
        z_e = self.proj(z)
        soft_one_hot = F.gumbel_softmax(z_e, tau=temp, dim=1, hard=hard)
        z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)

        # KL of the code posterior against a uniform prior over the codebook.
        qy = F.softmax(z_e, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + self.eps), dim=1).mean()

        embed_ind = soft_one_hot.argmax(dim=1)
        z_q = z_q.permute(0, 2, 3, 1)  # (B, H, W, C)
        return z_q, diff, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.weight)
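

# Minimal usage sketch: all three quantizers share the (z_q, diff, embed_ind) interface.
# The tensor sizes and hyperparameters below are illustrative assumptions, not values
# taken from any training configuration.
if __name__ == "__main__":
    x = torch.randn(2, 128, 8, 8)  # dummy encoder feature map (B, hidden_dim, H, W)

    for quantizer in (
        VectorQuantize(hidden_dim=128, embedding_dim=64, n_embed=512),
        VectorQuantizeEMA(hidden_dim=128, embedding_dim=64, n_embed=512),
        GumbelQuantize(hidden_dim=128, embedding_dim=64, n_embed=512),
    ):
        z_q, diff, embed_ind = quantizer(x)
        # z_q: (2, 8, 8, 64) channels-last quantized features,
        # diff: scalar auxiliary loss, embed_ind: (2, 8, 8) code indices.
        print(type(quantizer).__name__, z_q.shape, diff.item(), embed_ind.shape)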