import torch
import torch.nn as nn
import torch.nn.functional as F

from gist1.gpt import GPT
from gist1.vqvae import VQVAE
from utils.misc import load_params


class VQVAETransformer(nn.Module):
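    """A GPT prior over VQ-VAE code sequences.

    Sequences interleave location/direction tokens with the VQ-VAE codebook
    indices of the isovist observed at each location. Location tokens are
    offset by the codebook size, so the joint vocabulary has
    vocab_size + loc_vocab_size entries.
    """
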
    def __init__(self, args):
        super().__init__()
        # load_vqvae() writes the codebook size back into args['vocab_size'],
        # which is read below and by load_transformer().
        self.vqvae = self.load_vqvae(args)
        self.transformer = self.load_transformer(args)

        self.pkeep = args['pkeep']
        self.vqvae_vocab_size = args['vocab_size']
        self.loc_vocab_size = args['loc_vocab_size']
        self.block_size = args['block_size']

    def load_vqvae(self, args):
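        """Build the VQ-VAE from the vqvae_* entries in args.

        Also records the codebook size in args['vocab_size'] so that
        load_transformer can size the joint vocabulary.
        """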
        num_hiddens = args['vqvae_num_hiddens']
        num_residual_layers = args['vqvae_num_residual_layers']
        num_residual_hiddens = args['vqvae_num_residual_hiddens']
        num_embeddings = args['latent_dim']
        latent_dim = args['vqvae_latent_dim']
        commitment_cost = args['vqvae_commitment_cost']
        decay = args['vqvae_decay']
        model = VQVAE(num_hiddens, num_residual_layers, num_residual_hiddens,
                      num_embeddings, latent_dim, commitment_cost,
                      decay)

        # Record the codebook size; __init__ and load_transformer read it back.
        args['vocab_size'] = num_embeddings

        return model

    def load_transformer(self, args):
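        """Build the GPT prior over the combined VQ + location vocabulary."""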
        latent_dim = args['latent_dim']
        heads = args['heads']
        N = args['N']
        block_size = args['block_size']
        vocab_size = args['vocab_size'] + args['loc_vocab_size']
        model = GPT(vocab_size, latent_dim, N, heads, block_size)
        return model

    @torch.no_grad()
    def encode_to_z(self, x):
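        """Encode an isovist batch to quantized latents and flat code indices."""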
        quantized, indices = self.vqvae.encode(x)
        indices = indices.view(quantized.shape[0], -1)
        return quantized, indices

    @torch.no_grad()
    def z_to_isovist(self, indices):
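        """Decode a batch of code indices back to isovists.

        Indices outside the codebook (e.g. location tokens) are clamped to the
        last codebook entry before the embedding lookup.
        """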
        indices[indices > self.vqvae_vocab_size - 1] = self.vqvae_vocab_size - 1
        embedding_dim = self.vqvae.vq.embedding_dim
        ix_to_vectors = self.vqvae.vq.embedding(indices).reshape(indices.shape[0], -1, embedding_dim)
        ix_to_vectors = ix_to_vectors.permute(0, 2, 1)
        isovist = self.vqvae.decode(ix_to_vectors)
        return isovist

    def loc_to_indices(self, x):
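        """Shift location values into the token range above the VQ codebook."""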
        starting_index = self.vqvae_vocab_size
        indices = x.long() + starting_index
        return indices

    def indices_to_loc(self, indices):
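        """Map location tokens back to location values, clamped to the valid range."""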
        starting_index = self.vqvae_vocab_size
        locs = indices - starting_index
        locs[locs < 0] = 0
        locs[locs > (self.loc_vocab_size - 1)] = self.loc_vocab_size - 1
        return locs

    def seq_encode(self, locs, isovists):
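        """Interleave location tokens and VQ indices into one token sequence.

        For each step i, the location token is followed by the code indices of
        the corresponding isovist.
        """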
        indices_seq = []
        for i in range(isovists.shape[1]):
            loc = locs[:, i].unsqueeze(1)
            indices_seq.append(self.loc_to_indices(loc))
            isovist = isovists[:, i, :].unsqueeze(1)
            _, indices = self.encode_to_z(isovist)
            indices_seq.append(indices)
        indices = torch.cat(indices_seq, dim=1)
        return indices

    def forward(self, indices):
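        """Return next-token logits and shifted targets for a token sequence.

        During training, a fraction (1 - pkeep) of the input tokens is
        replaced by random VQ indices before being fed to the transformer.
        """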
        device = indices.device

        if self.training and self.pkeep < 1.0:
            # Randomly replace a fraction (1 - pkeep) of tokens with random VQ indices.
            mask = torch.bernoulli(self.pkeep * torch.ones(indices.shape, device=device))
            mask = mask.round().to(dtype=torch.int64)
            random_indices = torch.randint_like(indices, self.vqvae_vocab_size)
            new_indices = mask * indices + (1 - mask) * random_indices
        else:
            new_indices = indices

        # Predict token t+1 from tokens up to t.
        target = indices[:, 1:]
        logits = self.transformer(new_indices[:, :-1])

        return logits, target

    def top_k_logits(self, logits, k):
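        """Mask all logits outside the top k with -inf."""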
        v, _ = torch.topk(logits, k)
        out = logits.clone()
        out[out < v[..., [-1]]] = -float("inf")
        return out

    def sample(self, x, steps, temp=1.0, top_k=100, seed=None, step_size=17, zeroing=False):
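        """Autoregressively sample `steps` tokens and append them to x.

        Once the sequence exceeds block_size, the conditioning window is
        cropped relative to step_size; with zeroing, the first token of the
        cropped window is replaced by location/direction token 0.
        """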
        device = x.device
        is_train = False
        if self.transformer.training:
            is_train = True
            self.transformer.eval()
        block_size = self.block_size
        generator = None
        if seed is not None:
            generator = torch.Generator(device).manual_seed(seed)
        for k in range(steps):
            if x.size(1) < block_size:
                x_cond = x
            else:
                # Crop the conditioning window so it fits within block_size.
                remain = step_size - (x.size(1) % step_size)
                x_cond = x[:, -(block_size - remain):]
                if zeroing:
                    x_cond = x_cond.clone()
                    x_cond[:, 0] = self.vqvae_vocab_size
            logits = self.transformer(x_cond)
            logits = logits[:, -1, :] / temp

            if top_k is not None:
                logits = self.top_k_logits(logits, top_k)

            probs = F.softmax(logits, dim=-1)
            ix = torch.multinomial(probs, num_samples=1, generator=generator)
            x = torch.cat((x, ix), dim=1)

        if is_train:
            self.transformer.train()

        return x

    def get_loc(self, ploc, dir):
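        """Step from ploc to a neighboring grid cell given a direction code.

        Direction 0 keeps the current location; directions 1-8 move to the
        eight surrounding cells.
        """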
        if dir == 0:
            loc = ploc
        elif dir == 1:
            loc = (ploc[0] + 1, ploc[1])
        elif dir == 2:
            loc = (ploc[0] + 1, ploc[1] + 1)
        elif dir == 3:
            loc = (ploc[0], ploc[1] + 1)
        elif dir == 4:
            loc = (ploc[0] - 1, ploc[1] + 1)
        elif dir == 5:
            loc = (ploc[0] - 1, ploc[1])
        elif dir == 6:
            loc = (ploc[0] - 1, ploc[1] - 1)
        elif dir == 7:
            loc = (ploc[0], ploc[1] - 1)
        elif dir == 8:
            loc = (ploc[0] + 1, ploc[1] - 1)
        else:
            raise ValueError('Direction unknown')
        return loc

    def init_loc(self, x, step_size):
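        """Rebuild the location cache from an existing token sequence x.

        Returns a dict mapping grid locations to the code indices already
        generated there, plus the last visited location.
        """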
        device = x.device
        loc_dict = {}
        loc = None
        cached_loc = None
        if x.shape[1] > 1:
            steps = x.shape[1] - 1
            for k in range(steps):
                if k % step_size == 0:
                    # Every step_size-th token is a location/direction token.
                    dir = x[:, k].detach().item() - self.vqvae_vocab_size
                    if dir == 0:
                        loc = (0, 0)
                    else:
                        loc = self.get_loc(loc, dir)
                    loc_dict[loc] = torch.empty(1, 0).long().to(device)
                    cached_loc = loc
                else:
                    ix = x[:, [k]]
                    loc_dict[cached_loc] = torch.cat((loc_dict[cached_loc], ix), dim=1)
        return loc_dict, loc

    def sample_memorized(self, x, steps, temp=1.0, top_k=100, seed=None, step_size=17, zeroing=False):
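        """Sample like `sample`, but cache code indices per grid location.

        When a location is revisited, its cached isovist tokens are replayed
        instead of being re-sampled, so the same place yields the same tokens.
        """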
        device = x.device
        loc_dict, loc = self.init_loc(x, step_size)
        is_train = False
        if self.transformer.training:
            is_train = True
            self.transformer.eval()
        block_size = self.block_size
        generator = None
        if seed is not None:
            generator = torch.Generator(device).manual_seed(seed)
        is_visited = False
        cache_counter = 0

        for k in range(steps):

            if k % step_size == 0:
                # The last token in x is a location/direction token: update the
                # current location and check whether it was visited before.
                dir = x[:, -1].detach().item() - self.vqvae_vocab_size
                if dir == 0:
                    is_visited = False
                    loc = (0, 0)
                    loc_dict[loc] = torch.empty(1, 0).long().to(device)
                else:
                    loc = self.get_loc(loc, dir)
                    if loc in loc_dict:
                        is_visited = True
                        cache_counter = 0
                    else:
                        is_visited = False
                        loc_dict[loc] = torch.empty(1, 0).long().to(device)

            if x.size(1) < block_size:
                x_cond = x
            else:
                # Crop the conditioning window so it fits within block_size.
                remain = step_size - (x.size(1) % step_size)
                x_cond = x[:, -(block_size - remain):]
                if zeroing:
                    x_cond = x_cond.clone()
                    x_cond[:, 0] = self.vqvae_vocab_size

            if not is_visited:
                logits = self.transformer(x_cond)
                logits = logits[:, -1, :] / temp

                if top_k is not None:
                    logits = self.top_k_logits(logits, top_k)

                probs = F.softmax(logits, dim=-1)
                ix = torch.multinomial(probs, num_samples=1, generator=generator)

                loc_dict[loc] = torch.cat((loc_dict[loc], ix), dim=1)
            else:
                # Replay cached tokens for a revisited location. After the
                # 16th cached token (index 15) switch back to sampling.
                if cache_counter == 15:
                    is_visited = False
                ix = loc_dict[loc][:, [cache_counter]]
                cache_counter += 1

            x = torch.cat((x, ix), dim=1)

        if is_train:
            self.transformer.train()

        return x