import torch
from torch import nn
from torch import einsum
from torch.nn import functional as F


class VectorQuantize(nn.Module):
    """Vanilla VQ-VAE quantizer with a codebook loss and a commitment loss."""

    def __init__(self,
                 hidden_dim,
                 embedding_dim,
                 n_embed,
                 commitment_cost=1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.n_embed = n_embed
        self.commitment_cost = commitment_cost

        # 1x1 conv projects encoder features to the codebook dimension.
        self.proj = nn.Conv2d(hidden_dim, embedding_dim, 1)
        self.embed = nn.Embedding(n_embed, embedding_dim)
        self.embed.weight.data.uniform_(-1. / n_embed, 1. / n_embed)

    def forward(self, z):
        B, C, H, W = z.shape

        z_e = self.proj(z)
        z_e = z_e.permute(0, 2, 3, 1)  # (B, H, W, C)
        flatten = z_e.reshape(-1, self.embedding_dim)

        # Squared Euclidean distance from each vector to every codebook entry.
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed.weight.t()
            + self.embed.weight.pow(2).sum(1, keepdim=True).t()
        )
        _, embed_ind = (-dist).max(1)
        embed_ind = embed_ind.view(B, H, W)
        z_q = self.embed_code(embed_ind)

        # Commitment term pulls the encoder toward the chosen codes; the second
        # term moves the codebook toward the encoder outputs.
        diff = self.commitment_cost * (z_q.detach() - z_e).pow(2).mean() \
            + (z_q - z_e.detach()).pow(2).mean()

        # Straight-through estimator: gradients flow from z_q back into z_e.
        z_q = z_e + (z_q - z_e).detach()
        return z_q, diff, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.weight)
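
# Usage sketch (hyperparameters and shapes below are illustrative assumptions,
# not taken from this file):
#   vq = VectorQuantize(hidden_dim=256, embedding_dim=64, n_embed=512)
#   z = torch.randn(8, 256, 16, 16)   # (B, hidden_dim, H, W) encoder features
#   z_q, diff, ind = vq(z)            # z_q: (8, 16, 16, 64), ind: (8, 16, 16)
# Note that z_q is returned channels-last, i.e. (B, H, W, embedding_dim).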


class VectorQuantizeEMA(nn.Module):
    """VQ quantizer whose codebook is maintained by exponential moving averages."""

    def __init__(self,
                 hidden_dim,
                 embedding_dim,
                 n_embed,
                 commitment_cost=1,
                 decay=0.99,
                 eps=1e-5,
                 pre_proj=True,
                 training_loc=True):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.n_embed = n_embed
        self.commitment_cost = commitment_cost
        self.training_loc = training_loc

        self.pre_proj = pre_proj
        if self.pre_proj:
            self.proj = nn.Conv2d(hidden_dim, embedding_dim, 1)
        self.embed = nn.Embedding(n_embed, embedding_dim)
        self.embed.weight.data.uniform_(-1. / n_embed, 1. / n_embed)

        # EMA statistics: per-code usage counts and running sums of assigned vectors.
        self.register_buffer("cluster_size", torch.zeros(n_embed))
        self.register_buffer("embed_avg", self.embed.weight.data.clone())
        self.decay = decay
        self.eps = eps

    def forward(self, z):
        B, C, H, W = z.shape

        if self.pre_proj:
            z_e = self.proj(z)
        else:
            z_e = z
        z_e = z_e.permute(0, 2, 3, 1)  # (B, H, W, C)
        flatten = z_e.reshape(-1, self.embedding_dim)

        # Squared Euclidean distance from each vector to every codebook entry.
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed.weight.t()
            + self.embed.weight.pow(2).sum(1, keepdim=True).t()
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
        embed_ind = embed_ind.view(B, H, W)
        z_q = self.embed_code(embed_ind)

        if self.training and self.training_loc:
            # EMA codebook update: the standard VQ-VAE EMA rule, assumed here as
            # the intended use of the buffers registered in __init__.
            with torch.no_grad():
                embed_onehot_sum = embed_onehot.sum(0)                      # (n_embed,)
                embed_sum = flatten.transpose(0, 1) @ embed_onehot          # (D, n_embed)
                self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
                self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay)
                # Laplace smoothing keeps rarely used codes from collapsing to zero.
                n = self.cluster_size.sum()
                cluster_size = (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
                self.embed.weight.data.copy_(self.embed_avg / cluster_size.unsqueeze(1))

        # Only the commitment term is returned; the codebook is updated by the EMA.
        diff = self.commitment_cost * (z_q.detach() - z_e).pow(2).mean()

        # Straight-through estimator.
        z_q = z_e + (z_q - z_e).detach()
        return z_q, diff, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.weight)
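
# Usage sketch (illustrative assumptions, mirroring the example above):
#   vq_ema = VectorQuantizeEMA(hidden_dim=256, embedding_dim=64, n_embed=512)
#   z_q, diff, ind = vq_ema(torch.randn(8, 256, 16, 16))
# Here `diff` carries only the commitment term; in training mode the codebook
# itself is refreshed from the EMA statistics rather than by a codebook loss.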


class GumbelQuantize(nn.Module):
    """Gumbel-softmax quantizer: codes are selected by a relaxed categorical sample."""

    def __init__(self,
                 hidden_dim,
                 embedding_dim,
                 n_embed,
                 commitment_cost=1,
                 straight_through=True,
                 kl_weight=5e-4,
                 temp_init=1.,
                 eps=1e-5):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.n_embed = n_embed
        self.commitment_cost = commitment_cost
        self.kl_weight = kl_weight
        self.temperature = temp_init
        self.eps = eps

        # 1x1 conv predicts per-code logits instead of an embedding.
        self.proj = nn.Conv2d(hidden_dim, n_embed, 1)
        self.embed = nn.Embedding(n_embed, embedding_dim)
        self.embed.weight.data.uniform_(-1. / n_embed, 1. / n_embed)

        self.straight_through = straight_through

    def forward(self, z, temp=None):
        # Always use hard (one-hot) samples at eval time.
        hard = self.straight_through if self.training else True
        temp = self.temperature if temp is None else temp

        B, C, H, W = z.shape
        z_e = self.proj(z)  # (B, n_embed, H, W) logits

        soft_one_hot = F.gumbel_softmax(z_e, tau=temp, dim=1, hard=hard)
        z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)

        # KL of the code distribution against a uniform prior over the codebook.
        qy = F.softmax(z_e, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + self.eps), dim=1).mean()

        embed_ind = soft_one_hot.argmax(dim=1)
        z_q = z_q.permute(0, 2, 3, 1)  # (B, H, W, embedding_dim)
        return z_q, diff, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.weight)
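

# Minimal smoke test: a sketch with assumed hyperparameters (not part of the
# original module) that runs each quantizer once on random encoder-like features.
if __name__ == "__main__":
    B, C, H, W = 2, 256, 8, 8
    z = torch.randn(B, C, H, W)

    quantizers = [
        VectorQuantize(hidden_dim=C, embedding_dim=64, n_embed=512),
        VectorQuantizeEMA(hidden_dim=C, embedding_dim=64, n_embed=512),
        GumbelQuantize(hidden_dim=C, embedding_dim=64, n_embed=512),
    ]
    for q in quantizers:
        z_q, diff, ind = q(z)
        # Each quantizer returns channels-last codes, a scalar loss term,
        # and integer code indices of shape (B, H, W).
        print(type(q).__name__, tuple(z_q.shape), float(diff), tuple(ind.shape))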