Mikhael Johanes committed
Commit d491737 · 1 Parent(s): d8a5c26

push files
app.py ADDED
@@ -0,0 +1,255 @@
+ import streamlit as st
+ import numpy as np
+ import random
+ 
+ from gist1.vqvae_gpt import VQVAETransformer
+ from utils.misc import load_params
+ from utils.isoutil import plot_isovist_sequence_grid
+ 
+ import torch
+ 
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+ else:
+     device = torch.device("cpu")
+ 
+ model_paths = ["./models/vqvaegpt_1.pth",
+                "./models/vqvaegpt_2.pth",
+                "./models/vqvaegpt_3.pth"]
+ cfg_path = "./models/param.json"
+ cfg = load_params(cfg_path)
+ 
+ 
+ @st.cache_resource
+ def get_model(index):
+     transformer_path = model_paths[index]
+     transformer = VQVAETransformer(cfg)
+     transformer.load_state_dict(torch.load(transformer_path))
+     transformer = transformer.to(device)
+     transformer.eval()
+     return transformer
+ 
+ 
+ def split_indices(indices, loc_len=1, isovist_len=16):
+     seg_length = loc_len + isovist_len
+     batch_size = indices.shape[0]
+     splits = indices.reshape(batch_size, -1, seg_length)  # B x S x (L+I)
+     ilocs, iisovists = torch.split(splits, [loc_len, isovist_len], dim=2)  # B x S x L, B x S x I
+     return ilocs, iisovists
+ 
+ 
+ @st.cache_data
+ def indices_to_loc(_model, indices):
+     indices = torch.tensor(indices).long().view(1, -1).to(device)
+     return _model.indices_to_loc(indices).detach().cpu().numpy()
+ 
+ 
+ @st.cache_data
+ def indices_to_isovist(_model, indices):
+     indices = torch.tensor(indices).long().view(1, -1).to(device)
+     return _model.z_to_isovist(indices).detach().cpu().numpy()
+ 
+ 
+ def indices_to_loc_isovist(model, indices):
+     ilocs, iisovists = split_indices(indices, loc_len=1, isovist_len=16)
+     locs = []
+     sampled_isovists = []
+     for i in range(iisovists.shape[1]):
+         iloc = ilocs[:, i, :].squeeze().tolist()
+         iisovist = iisovists[:, i, :].squeeze().tolist()
+         iisovist = tuple(iisovist)  # hashable, so st.cache_data can key on it
+         locs.append(indices_to_loc(model, iloc))
+         sampled_isovists.append(indices_to_isovist(model, iisovist))
+ 
+     locs = np.stack(locs, axis=1)
+     sampled_isovists = np.stack(sampled_isovists, axis=1)  # B x S x C x W
+     return locs, sampled_isovists
+ 
+ 
+ def plot_isovist(locs, sampled_isovists, lim, alpha, calculate_lim):
+     loc = locs[0]
+     sampled_isovist = sampled_isovists[0]
+     sampled_isovist = np.squeeze(sampled_isovist, axis=1)
+     fig = plot_isovist_sequence_grid(loc, sampled_isovist, figsize=(8, 6), center=True, lim=lim,
+                                      alpha=alpha, calculate_lim=calculate_lim).transpose((1, 2, 0))
+     return fig
+ 
+ 
+ def sample(model, start_indices, top_k=100, seed=0, seq_length=None, zeroing=False,
+            lim=1.5, alpha=0.02, loc_init=False, calculate_lim=False):
+     start_indices = start_indices.long().to(device)
+     steps = seq_length * (1 + 16)  # one location token + 16 latent codes per step
+     if loc_init:
+         steps -= 1
+     sample_indices = model.sample_memorized(start_indices, steps=steps, top_k=top_k, seed=seed, zeroing=zeroing)
+     locs, sampled_isovists = indices_to_loc_isovist(model, sample_indices)
+     im = plot_isovist(locs, sampled_isovists, lim, alpha, calculate_lim)
+     return im, sample_indices
+ 
+ 
+ def plot_indices(model, indices, lim=1.5, alpha=0.02, calculate_lim=False):
+     locs, sampled_isovists = indices_to_loc_isovist(model, indices)
+     im = plot_isovist(locs, sampled_isovists, lim, alpha, calculate_lim)
+     return im
+ 
+ 
+ st.subheader("GIsT: Generative Isovist Transformers")
+ st.text("Press [init] to initiate or start over")
+ 
+ options = ["Base model", "Palladio", "Mies"]
+ 
+ if 'model' not in st.session_state:
+     st.session_state.model = None
+ 
+ if st.session_state.model is not None:
+     index = options.index(st.session_state.model)
+ else:
+     index = 0
+ 
+ option = st.selectbox("Select model", options, index=index)
+ st.session_state.model = option
+ 
+ if 'tokens' not in st.session_state:
+     st.session_state.tokens = None
+ 
+ if 'image' not in st.session_state:
+     st.session_state.image = np.ones((600, 800, 3), dtype=np.uint8) * 240
+ 
+ if 'seed' not in st.session_state:
+     st.session_state.seed = random.randint(0, 10000000)
+ 
+ index = options.index(st.session_state.model)
+ transformer = get_model(index)
+ 
+ # direction tokens: they follow the 1024 VQ-VAE codebook indices and the start token (1024)
+ e = 1025
+ ne = 1026
+ n = 1027
+ nw = 1028
+ w = 1029
+ sw = 1030
+ s = 1031
+ se = 1032
+ 
+ alpha = 0.015
+ lim = 2.0
+ 
+ init = st.button('init')
+ 
+ cont = st.container()
+ 
+ rows = []
+ for i in range(3):
+     rows.append(st.columns(3, gap='small'))
+ 
+ upleft = rows[0][0].button('upleft', use_container_width=True)
+ up = rows[0][1].button('up', use_container_width=True)
+ upright = rows[0][2].button('upright', use_container_width=True)
+ left = rows[1][0].button('left', use_container_width=True)
+ undo = rows[1][1].button('undo', use_container_width=True)
+ right = rows[1][2].button('right', use_container_width=True)
+ downleft = rows[2][0].button('downleft', use_container_width=True)
+ down = rows[2][1].button('down', use_container_width=True)
+ downright = rows[2][2].button('downright', use_container_width=True)
+ 
+ st.text("use desktop mode for the best experience on a mobile device")
+ 
+ seed = st.number_input('seed', 0, 10000000, st.session_state.seed, 1)
+ 
+ 
+ def gen_next(sample_indices, dir):
+     sample_indices = torch.concat([sample_indices, torch.tensor([[dir]]).to(device)], dim=1)
+     im, sample_indices = sample(transformer, sample_indices, top_k=50, seq_length=1, seed=seed,
+                                 lim=lim, alpha=alpha, loc_init=True, calculate_lim=True)
+     return im, sample_indices
+ 
+ 
+ def undo_gen(sample_indices):
+     sample_indices = sample_indices[:, :-17]  # drop the last step: 1 location token + 16 latent codes
+     im = plot_indices(transformer, sample_indices, lim=lim, alpha=alpha, calculate_lim=True)
+     return im, sample_indices
+ 
+ 
+ if init:
+     st.session_state.tokens = torch.ones((1, 1)).long().to(device) * 1024  # start/origin token
+     tokens = st.session_state.tokens
+     im, sample_indices = sample(transformer, tokens, top_k=50, seq_length=1, seed=seed,
+                                 lim=lim, alpha=alpha, loc_init=True)
+     st.session_state.image = im
+     st.session_state.tokens = sample_indices
+     st.session_state.lim = 2.0
+ 
+ if upleft:
+     if st.session_state.tokens is not None:
+         st.session_state.image, st.session_state.tokens = gen_next(st.session_state.tokens, nw)
+     else:
+         st.warning('Please init the generation')
+ 
+ if up:
+     if st.session_state.tokens is not None:
+         st.session_state.image, st.session_state.tokens = gen_next(st.session_state.tokens, n)
+     else:
+         st.warning('Please init the generation')
+ 
+ if upright:
+     if st.session_state.tokens is not None:
+         st.session_state.image, st.session_state.tokens = gen_next(st.session_state.tokens, ne)
+     else:
+         st.warning('Please init the generation')
+ 
+ if left:
+     if st.session_state.tokens is not None:
+         st.session_state.image, st.session_state.tokens = gen_next(st.session_state.tokens, w)
+     else:
+         st.warning('Please init the generation')
+ 
+ if right:
+     if st.session_state.tokens is not None:
+         st.session_state.image, st.session_state.tokens = gen_next(st.session_state.tokens, e)
+     else:
+         st.warning('Please init the generation')
+ 
+ if downleft:
+     if st.session_state.tokens is not None:
+         st.session_state.image, st.session_state.tokens = gen_next(st.session_state.tokens, sw)
+     else:
+         st.warning('Please init the generation')
+ 
+ if down:
+     if st.session_state.tokens is not None:
+         st.session_state.image, st.session_state.tokens = gen_next(st.session_state.tokens, s)
+     else:
+         st.warning('Please init the generation')
+ 
+ if downright:
+     if st.session_state.tokens is not None:
+         st.session_state.image, st.session_state.tokens = gen_next(st.session_state.tokens, se)
+     else:
+         st.warning('Please init the generation')
+ 
+ if undo:
+     if st.session_state.tokens is not None:
+         if st.session_state.tokens.shape[1] >= 34:  # keep at least the init step plus one move
+             st.session_state.image, st.session_state.tokens = undo_gen(st.session_state.tokens)
+         else:
+             st.warning('no more steps to undo')
+     else:
+         st.warning('Please init the generation')
+ 
+ cont.image(st.session_state.image)
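For orientation, the app's token bookkeeping hinges on one invariant: every generation step occupies exactly 17 tokens (1 direction token followed by 16 VQ-VAE code indices), which is why undo_gen trims 17 tokens and the undo guard checks for at least 34 (the init step plus one move). A minimal standalone sketch of that layout (an editor's illustration, not part of the commit):

    import torch

    tokens = torch.arange(34).view(1, 34)             # two steps of 17 tokens each
    splits = tokens.reshape(1, -1, 17)                # B x S x 17
    ilocs, iisovists = torch.split(splits, [1, 16], dim=2)
    print(ilocs.shape, iisovists.shape)               # torch.Size([1, 2, 1]) torch.Size([1, 2, 16])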
gist1/__pycache__/gpt.cpython-38.pyc ADDED
Binary file (6.31 kB).
 
gist1/__pycache__/vqvae.cpython-38.pyc ADDED
Binary file (7.42 kB).
 
gist1/__pycache__/vqvae_gpt.cpython-38.pyc ADDED
Binary file (7.12 kB).
 
gist1/gpt.py ADDED
@@ -0,0 +1,192 @@
+ # references:
+ # https://blog.floydhub.com/the-transformer-in-pytorch/
+ # https://github.com/hyunwoongko/transformer for the transformer architecture
+ # https://github.com/Whiax/BERT-Transformer-Pytorch/blob/main/train.py (norm layer first)
+ # https://github.com/karpathy/nanoGPT
+ 
+ import copy
+ 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.optim.lr_scheduler import _LRScheduler
+ import numpy as np
+ 
+ 
+ def new_gelu(x):
+     """
+     Implementation of the GELU activation function currently in the Google BERT repo (identical to OpenAI GPT).
+     Reference: the Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
+     """
+     return 0.5 * x * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+ 
+ 
+ # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/fec78a687210851f055f792d45300d27cc60ae41/transformer/Modules.py
+ class ScaledDotProductAttention(nn.Module):
+     def __init__(self, temperature, dropout=0.1):
+         super().__init__()
+         self.temperature = temperature
+         self.dropout = nn.Dropout(dropout)
+ 
+     def forward(self, q, k, v, mask=None):
+         attn = torch.matmul(q / self.temperature, k.transpose(-2, -1))
+ 
+         if mask is not None:
+             attn = attn.masked_fill(mask == 0, -1e9)
+ 
+         attn = F.softmax(attn, dim=-1)
+         attn = self.dropout(attn)
+         output = torch.matmul(attn, v)
+ 
+         return output
+ 
+ 
+ class CausalMultiHeadAttention(nn.Module):
+     def __init__(self, heads, d_model, block_size, dropout=0.1):
+         super().__init__()
+ 
+         self.d_model = d_model
+         self.d_k = d_model // heads
+         self.h = heads
+ 
+         self.q_linear = nn.Linear(d_model, d_model, bias=False)
+         self.v_linear = nn.Linear(d_model, d_model, bias=False)
+         self.k_linear = nn.Linear(d_model, d_model, bias=False)
+ 
+         self.attention = ScaledDotProductAttention(temperature=self.d_k**0.5)
+ 
+         self.out = nn.Linear(d_model, d_model, bias=False)
+ 
+         # causal mask
+         self.register_buffer("causal_mask", torch.tril(torch.ones(block_size, block_size))
+                              .view(1, 1, block_size, block_size))
+ 
+         self.dropout = nn.Dropout(dropout)
+ 
+     def forward(self, q, k, v):
+         bs, T, C = q.size()
+ 
+         # linear projections, split into h heads
+         k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
+         q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
+         v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
+ 
+         # transpose to get dimensions bs x h x sl x d_k
+         k = k.transpose(1, 2)
+         q = q.transpose(1, 2)
+         v = v.transpose(1, 2)
+ 
+         # causal mask, cropped to the current sequence length
+         mask = self.causal_mask[:, :, :T, :T]
+ 
+         # calculate attention
+         attn = self.attention(q, k, v, mask)
+ 
+         # concatenate heads and put through the final linear layer
+         concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
+ 
+         output = self.dropout(self.out(concat))
+ 
+         return output
+ 
+ 
+ class FeedForward(nn.Module):
+     def __init__(self, d_model, dropout=0.1):
+         super().__init__()
+         # the hidden width defaults to 4 * d_model
+         self.linear_1 = nn.Linear(d_model, 4 * d_model)
+         self.dropout = nn.Dropout(dropout)
+         self.linear_2 = nn.Linear(4 * d_model, d_model)
+ 
+     def forward(self, x):
+         x = self.linear_1(x)
+         x = new_gelu(x)
+         x = self.linear_2(x)
+         x = self.dropout(x)
+         return x
+ 
+ 
+ # pre-norm block; the implementation references https://www.arxiv-vanity.com/papers/1911.03179/
+ class Block(nn.Module):
+     def __init__(self, d_model, heads, block_size, dropout=0.1):
+         super().__init__()
+         self.norm_1 = nn.LayerNorm(d_model, eps=1e-6)
+         self.norm_2 = nn.LayerNorm(d_model, eps=1e-6)
+         self.attn = CausalMultiHeadAttention(heads, d_model, block_size)
+         self.ff = FeedForward(d_model)
+ 
+     def forward(self, x):
+         # normalize
+         x2 = self.norm_1(x)
+         # self-attention
+         x2 = self.attn(x2, x2, x2)
+         # residual
+         x = x + x2
+         # normalize
+         x2 = self.norm_2(x)
+         # position-wise feed-forward network
+         x2 = self.ff(x2)
+         # residual
+         x = x + x2
+         return x
+ 
+ 
+ # layer multiplier
+ def get_clones(module, N):
+     return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+ 
+ 
+ class GPT(nn.Module):
+     def __init__(self, vocab_size, d_model, N, heads, block_size=80, dropout=0.1):
+         super().__init__()
+         self.N = N
+         self.embed = nn.Embedding(vocab_size, d_model)
+         self.pe = nn.Parameter(torch.zeros(1, block_size, d_model))  # learned positional embedding
+         self.dropout = nn.Dropout(dropout)
+         self.layers = get_clones(Block(d_model, heads, block_size), N)
+         self.norm = nn.LayerNorm(d_model, eps=1e-6)
+         self.out = nn.Linear(d_model, vocab_size, bias=False)
+         self.apply(self._init_weights)
+ 
+     def _init_weights(self, module):
+         if isinstance(module, (nn.Linear, nn.Embedding)):
+             module.weight.data.normal_(mean=0.0, std=0.02)
+             if isinstance(module, nn.Linear) and module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+ 
+     def forward(self, src):
+         b, t = src.size()
+         tok_emb = self.embed(src)
+         position_embeddings = self.pe[:, :t, :]
+         x = tok_emb + position_embeddings
+         x = self.dropout(x)
+         x = self.norm(x)
+         for i in range(self.N):
+             x = self.layers[i](x)
+         x = self.norm(x)
+         x = self.out(x)
+         return x
+ 
+ 
+ class Scheduler(_LRScheduler):
+     def __init__(self, optimizer, dim_embed, warmup_steps, last_epoch=-1, verbose=False):
+         self.dim_embed = dim_embed
+         self.warmup_steps = warmup_steps
+         self.num_param_groups = len(optimizer.param_groups)
+         super().__init__(optimizer, last_epoch, verbose)
+ 
+     def get_lr(self):
+         lr = self.dim_embed**(-0.5) * min(self._step_count**(-0.5), self._step_count * self.warmup_steps**(-1.5))
+         return [lr] * self.num_param_groups
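The Scheduler above is the inverse-square-root warmup schedule from "Attention Is All You Need": lr = dim_embed^-0.5 * min(step^-0.5, step * warmup_steps^-1.5), rising linearly for warmup_steps and then decaying as step^-0.5. A hedged usage sketch (editor's illustration, not in the commit; the model sizes and optimizer settings here are assumptions, not taken from this repo):

    import torch
    from gist1.gpt import GPT, Scheduler

    model = GPT(vocab_size=1033, d_model=64, N=2, heads=4, block_size=255)  # toy sizes (assumption)
    optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
    scheduler = Scheduler(optimizer, dim_embed=64, warmup_steps=4000)

    for step in range(100):   # skeleton loop; forward/loss/backward omitted
        optimizer.step()
        scheduler.step()      # lr warms up for 4000 steps, then decays as step**-0.5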
gist1/vqvae.py ADDED
@@ -0,0 +1,290 @@
+ # reference: https://github.com/zalandoresearch/pytorch-vq-vae
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ 
+ 
+ class VectorQuantizer(nn.Module):
+     def __init__(self, num_embeddings, embedding_dim, commitment_cost):
+         super().__init__()
+ 
+         self.embedding_dim = embedding_dim
+         self.num_embeddings = num_embeddings
+ 
+         self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
+         self.embedding.weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
+         self.commitment_cost = commitment_cost
+ 
+     def forward(self, inputs):
+         # convert inputs from BCW -> BWC
+         inputs = inputs.permute(0, 2, 1).contiguous()
+         input_shape = inputs.shape
+ 
+         # flatten input
+         flat_input = inputs.view(-1, self.embedding_dim)
+ 
+         # calculate distances
+         distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
+                      + torch.sum(self.embedding.weight**2, dim=1)
+                      - 2 * torch.matmul(flat_input, self.embedding.weight.t()))
+ 
+         # encoding
+         encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
+         encodings = torch.zeros(encoding_indices.shape[0], self.num_embeddings, device=inputs.device)
+         encodings.scatter_(1, encoding_indices, 1)
+ 
+         # quantize and unflatten
+         quantized = torch.matmul(encodings, self.embedding.weight).view(input_shape)
+ 
+         # loss
+         e_latent_loss = F.mse_loss(quantized.detach(), inputs)
+         q_latent_loss = F.mse_loss(quantized, inputs.detach())
+         loss = q_latent_loss + self.commitment_cost * e_latent_loss
+ 
+         # straight-through estimator
+         quantized = inputs + (quantized - inputs).detach()
+         avg_probs = torch.mean(encodings, dim=0)
+         perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+ 
+         # convert quantized from BWC -> BCW
+         return loss, quantized.permute(0, 2, 1).contiguous(), perplexity, encodings
+ 
+ 
+ class VectorQuantizerEMA(nn.Module):
+     def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
+         super().__init__()
+ 
+         self.embedding_dim = embedding_dim
+         self.num_embeddings = num_embeddings
+ 
+         self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
+         self.embedding.weight.data.normal_()
+         self.commitment_cost = commitment_cost
+ 
+         self.register_buffer('ema_cluster_size', torch.zeros(num_embeddings))
+         self.ema_w = nn.Parameter(torch.Tensor(num_embeddings, self.embedding_dim))
+         self.ema_w.data.normal_()
+ 
+         self.decay = decay
+         self.epsilon = epsilon
+ 
+     def forward(self, inputs):
+         # convert inputs from BCW -> BWC
+         inputs = inputs.permute(0, 2, 1).contiguous()
+         input_shape = inputs.shape
+ 
+         # flatten input
+         flat_input = inputs.view(-1, self.embedding_dim)
+ 
+         # calculate distances
+         distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
+                      + torch.sum(self.embedding.weight**2, dim=1)
+                      - 2 * torch.matmul(flat_input, self.embedding.weight.t()))
+ 
+         # encoding
+         encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
+         encodings = torch.zeros(encoding_indices.shape[0], self.num_embeddings, device=inputs.device)
+         encodings.scatter_(1, encoding_indices, 1)
+ 
+         # quantize and unflatten
+         quantized = torch.matmul(encodings, self.embedding.weight).view(input_shape)
+ 
+         # use EMA to update the embedding vectors
+         if self.training:
+             self.ema_cluster_size = self.ema_cluster_size * self.decay + (1 - self.decay) * torch.sum(encodings, 0)
+ 
+             # Laplace smoothing of the cluster sizes
+             # (parenthesization follows the cited reference implementation)
+             n = torch.sum(self.ema_cluster_size)
+             self.ema_cluster_size = ((self.ema_cluster_size + self.epsilon)
+                                      / (n + self.num_embeddings * self.epsilon) * n)
+             dw = torch.matmul(encodings.t(), flat_input)
+             self.ema_w = nn.Parameter(self.ema_w * self.decay + (1 - self.decay) * dw)
+ 
+             self.embedding.weight = nn.Parameter(self.ema_w / self.ema_cluster_size.unsqueeze(1))
+ 
+         # loss
+         e_latent_loss = F.mse_loss(quantized.detach(), inputs)
+         loss = self.commitment_cost * e_latent_loss
+ 
+         # straight-through estimator
+         quantized = inputs + (quantized - inputs).detach()
+         avg_probs = torch.mean(encodings, dim=0)
+         perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+ 
+         # convert quantized from BWC -> BCW
+         return loss, quantized.permute(0, 2, 1).contiguous(), perplexity, encoding_indices
+ 
+ 
+ class Residual(nn.Module):
+     def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
+         super().__init__()
+         self.block = nn.Sequential(nn.ReLU(inplace=True),
+                                    nn.Conv1d(in_channels=in_channels,
+                                              out_channels=num_residual_hiddens,
+                                              kernel_size=3, stride=1, padding=1, bias=False,
+                                              padding_mode='circular'),
+                                    nn.ReLU(inplace=True),
+                                    nn.Conv1d(in_channels=num_residual_hiddens,
+                                              out_channels=num_hiddens,
+                                              kernel_size=1, stride=1, bias=False))
+ 
+     def forward(self, x):
+         return x + self.block(x)
+ 
+ 
+ class ResidualStack(nn.Module):
+     def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
+         super().__init__()
+         self.num_residual_layers = num_residual_layers
+         self.layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
+                                      for _ in range(self.num_residual_layers)])
+ 
+     def forward(self, x):
+         for i in range(self.num_residual_layers):
+             x = self.layers[i](x)
+         return F.relu(x)
+ 
+ 
+ class Encoder(nn.Module):
+     def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
+         super().__init__()
+         # 256 -> 128
+         self.conv_1 = nn.Conv1d(in_channels=in_channels,
+                                 out_channels=num_hiddens // 2,
+                                 kernel_size=4,
+                                 stride=2, padding=1, padding_mode='circular')
+         # 128 -> 64
+         self.conv_2 = nn.Conv1d(in_channels=num_hiddens // 2,
+                                 out_channels=num_hiddens,
+                                 kernel_size=4,
+                                 stride=2, padding=1, padding_mode='circular')
+         # 64 -> 32
+         self.conv_3 = nn.Conv1d(in_channels=num_hiddens,
+                                 out_channels=num_hiddens,
+                                 kernel_size=4,
+                                 stride=2, padding=1, padding_mode='circular')
+         # 32 -> 16
+         self.conv_4 = nn.Conv1d(in_channels=num_hiddens,
+                                 out_channels=num_hiddens,
+                                 kernel_size=4,
+                                 stride=2, padding=1, padding_mode='circular')
+         self.conv_final = nn.Conv1d(in_channels=num_hiddens,
+                                     out_channels=num_hiddens,
+                                     kernel_size=3,
+                                     stride=1, padding=1, padding_mode='circular')
+         self.residual_stack = ResidualStack(in_channels=num_hiddens,
+                                             num_hiddens=num_hiddens,
+                                             num_residual_hiddens=num_residual_hiddens,
+                                             num_residual_layers=num_residual_layers)
+ 
+     def forward(self, inputs):
+         x = self.conv_1(inputs)
+         x = F.relu(x)
+ 
+         x = self.conv_2(x)
+         x = F.relu(x)
+ 
+         x = self.conv_3(x)
+         x = F.relu(x)
+ 
+         x = self.conv_4(x)
+         x = F.relu(x)
+ 
+         x = self.conv_final(x)
+         x = self.residual_stack(x)
+ 
+         return x
+ 
+ 
+ class Decoder(nn.Module):
+     def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
+         super().__init__()
+         self.conv_init = nn.Conv1d(in_channels=in_channels,
+                                    out_channels=num_hiddens,
+                                    kernel_size=3,
+                                    stride=1, padding=1)
+         self.residual_stack = ResidualStack(in_channels=num_hiddens,
+                                             num_hiddens=num_hiddens,
+                                             num_residual_layers=num_residual_layers,
+                                             num_residual_hiddens=num_residual_hiddens)
+         # 16 -> 32
+         self.conv_trans_0 = nn.ConvTranspose1d(in_channels=num_hiddens,
+                                                out_channels=num_hiddens,
+                                                kernel_size=4,
+                                                stride=2, padding=1)
+         # 32 -> 64
+         self.conv_trans_1 = nn.ConvTranspose1d(in_channels=num_hiddens,
+                                                out_channels=num_hiddens,
+                                                kernel_size=4,
+                                                stride=2, padding=1)
+         # 64 -> 128
+         self.conv_trans_2 = nn.ConvTranspose1d(in_channels=num_hiddens,
+                                                out_channels=num_hiddens // 2,
+                                                kernel_size=4,
+                                                stride=2, padding=1)
+         # 128 -> 256
+         self.conv_trans_3 = nn.ConvTranspose1d(in_channels=num_hiddens // 2,
+                                                out_channels=1,
+                                                kernel_size=4,
+                                                stride=2, padding=1)
+ 
+     def forward(self, inputs):
+         x = self.conv_init(inputs)
+ 
+         x = self.residual_stack(x)
+ 
+         x = self.conv_trans_0(x)
+         x = F.relu(x)
+ 
+         x = self.conv_trans_1(x)
+         x = F.relu(x)
+ 
+         x = self.conv_trans_2(x)
+         x = F.relu(x)
+ 
+         return self.conv_trans_3(x)
+ 
+ 
+ class VQVAE(nn.Module):
+     def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens, num_embeddings,
+                  embedding_dim, commitment_cost, decay=0):
+         super().__init__()
+         self.encoder = Encoder(1, num_hiddens,
+                                num_residual_layers,
+                                num_residual_hiddens)
+         self.pre_vq_conv = nn.Conv1d(in_channels=num_hiddens,
+                                      out_channels=embedding_dim,
+                                      kernel_size=1,
+                                      stride=1)
+ 
+         if decay > 0.0:
+             self.vq = VectorQuantizerEMA(num_embeddings, embedding_dim, commitment_cost, decay)
+         else:
+             self.vq = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
+ 
+         self.decoder = Decoder(embedding_dim,
+                                num_hiddens,
+                                num_residual_layers,
+                                num_residual_hiddens)
+ 
+     def encode(self, x):
+         z = self.encoder(x)
+         z = self.pre_vq_conv(z)
+         _, quantized, _, encoding_indices = self.vq(z)
+         return quantized, encoding_indices
+ 
+     def decode(self, x):
+         return self.decoder(x)
+ 
+     def forward(self, x):
+         z = self.encoder(x)
+         z = self.pre_vq_conv(z)
+ 
+         loss, quantized, perplexity, _ = self.vq(z)
+         x_recon = self.decoder(quantized)
+ 
+         return loss, x_recon, perplexity
gist1/vqvae_gpt.py ADDED
@@ -0,0 +1,288 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from gist1.gpt import GPT
+ from gist1.vqvae import VQVAE
+ 
+ 
+ class VQVAETransformer(nn.Module):
+     def __init__(self, args):
+         super().__init__()
+         self.vqvae = self.load_vqvae(args)
+         self.transformer = self.load_transformer(args)
+         self.pkeep = args['pkeep']
+         self.vqvae_vocab_size = args['vocab_size']
+         self.loc_vocab_size = args['loc_vocab_size']
+         self.block_size = args['block_size']
+ 
+     def load_vqvae(self, args):
+         num_hiddens = args['vqvae_num_hiddens']
+         num_residual_layers = args['vqvae_num_residual_layers']
+         num_residual_hiddens = args['vqvae_num_residual_hiddens']
+         num_embeddings = args['latent_dim']
+         latent_dim = args['vqvae_latent_dim']
+         commitment_cost = args['vqvae_commitment_cost']
+         decay = args['vqvae_decay']
+         model = VQVAE(num_hiddens, num_residual_layers, num_residual_hiddens,
+                       num_embeddings, latent_dim, commitment_cost, decay)
+ 
+         # update args from the vqvae config
+         args['vocab_size'] = num_embeddings
+ 
+         return model
+ 
+     def load_vqvae_weight(self, args):
+         vqvae_path = args['vqvae_checkpoint']
+         self.vqvae.load_state_dict(torch.load(vqvae_path))
+         self.vqvae.eval()
+ 
+     def load_transformer(self, args):
+         latent_dim = args['latent_dim']
+         heads = args['heads']
+         N = args['N']
+         block_size = args['block_size']
+         vocab_size = args['vocab_size'] + args['loc_vocab_size']
+         model = GPT(vocab_size, latent_dim, N, heads, block_size)
+         return model
+ 
+     @torch.no_grad()
+     def encode_to_z(self, x):
+         quantized, indices = self.vqvae.encode(x)
+         indices = indices.view(quantized.shape[0], -1)
+         return quantized, indices
+ 
+     @torch.no_grad()
+     def z_to_isovist(self, indices):
+         indices[indices > self.vqvae_vocab_size - 1] = self.vqvae_vocab_size - 1
+         embedding_dim = self.vqvae.vq.embedding_dim
+         ix_to_vectors = self.vqvae.vq.embedding(indices).reshape(indices.shape[0], -1, embedding_dim)
+         ix_to_vectors = ix_to_vectors.permute(0, 2, 1)
+         isovist = self.vqvae.decode(ix_to_vectors)
+         return isovist
+ 
+     def loc_to_indices(self, x):
+         starting_index = self.vqvae_vocab_size
+         indices = x.long() + starting_index
+         return indices
+ 
+     def indices_to_loc(self, indices):
+         starting_index = self.vqvae_vocab_size
+         locs = indices - starting_index
+         locs[locs < 0] = 0
+         locs[locs > (self.loc_vocab_size - 1)] = self.loc_vocab_size - 1
+         return locs
+ 
+     def seq_encode(self, locs, isovists):
+         # B x S x W
+         indices_seq = []
+         for i in range(isovists.shape[1]):  # iterate through the sequence
+             loc = locs[:, i].unsqueeze(1)  # B x L
+             indices_seq.append(self.loc_to_indices(loc))
+             isovist = isovists[:, i, :].unsqueeze(1)  # B x C x W
+             _, indices = self.encode_to_z(isovist)
+             indices_seq.append(indices)
+         indices = torch.cat(indices_seq, dim=1)
+         return indices
+ 
+     def forward(self, indices):
+         device = indices.device
+ 
+         if self.training and self.pkeep < 1.0:
+             mask = torch.bernoulli(self.pkeep * torch.ones(indices.shape, device=device))
+             mask = mask.round().to(dtype=torch.int64)
+             random_indices = torch.randint_like(indices, self.vqvae_vocab_size)  # doesn't include the sos token
+             new_indices = mask * indices + (1 - mask) * random_indices
+         else:
+             new_indices = indices
+ 
+         target = indices[:, 1:]
+         logits = self.transformer(new_indices[:, :-1])
+ 
+         return logits, target
+ 
+     def top_k_logits(self, logits, k):
+         v, ix = torch.topk(logits, k)
+         out = logits.clone()
+         out[out < v[..., [-1]]] = -float("inf")
+         return out
+ 
+     def sample(self, x, steps, temp=1.0, top_k=100, seed=None, step_size=17, zeroing=False):
+         device = x.device
+         is_train = False
+         if self.transformer.training:
+             is_train = True
+             self.transformer.eval()
+         block_size = self.block_size
+         generator = None
+         if seed is not None:
+             generator = torch.Generator(device).manual_seed(seed)
+         for k in range(steps):
+             if x.size(1) < block_size:
+                 x_cond = x
+             else:
+                 remain = step_size - (x.size(1) % step_size)
+                 x_cond = x[:, -(block_size - remain):]  # crop context if needed
+                 if zeroing:
+                     x_cond = x_cond.clone()
+                     x_cond[:, 0] = self.vqvae_vocab_size
+             logits = self.transformer(x_cond)
+             logits = logits[:, -1, :] / temp
+ 
+             if top_k is not None:
+                 logits = self.top_k_logits(logits, top_k)
+ 
+             probs = F.softmax(logits, dim=-1)
+ 
+             ix = torch.multinomial(probs, num_samples=1, generator=generator)
+ 
+             x = torch.cat((x, ix), dim=1)
+ 
+         if is_train:
+             self.transformer.train()
+ 
+         return x
+ 
+     def get_loc(self, ploc, dir):
+         if dir == 0:
+             loc = ploc
+         elif dir == 1:
+             loc = (ploc[0] + 1, ploc[1])
+         elif dir == 2:
+             loc = (ploc[0] + 1, ploc[1] + 1)
+         elif dir == 3:
+             loc = (ploc[0], ploc[1] + 1)
+         elif dir == 4:
+             loc = (ploc[0] - 1, ploc[1] + 1)
+         elif dir == 5:
+             loc = (ploc[0] - 1, ploc[1])
+         elif dir == 6:
+             loc = (ploc[0] - 1, ploc[1] - 1)
+         elif dir == 7:
+             loc = (ploc[0], ploc[1] - 1)
+         elif dir == 8:
+             loc = (ploc[0] + 1, ploc[1] - 1)
+         else:
+             raise NameError('Direction unknown')
+         return loc
+ 
+     def init_loc(self, x, step_size):
+         device = x.device
+         loc_dict = {}
+         loc = None
+         cached_loc = None
+         if x.shape[1] > 1:
+             steps = x.shape[1] - 1
+             for k in range(steps):
+                 if k % step_size == 0:
+                     dir = x[:, k].detach().item() - self.vqvae_vocab_size
+                     if dir == 0:
+                         loc = (0, 0)  # init loc
+                     else:
+                         loc = self.get_loc(loc, dir)
+                     loc_dict[loc] = torch.empty(1, 0).long().to(device)
+                     cached_loc = loc
+                 else:
+                     ix = x[:, [k]]
+                     loc_dict[cached_loc] = torch.cat((loc_dict[cached_loc], ix), dim=1)
+         return loc_dict, loc
+ 
+     def sample_memorized(self, x, steps, temp=1.0, top_k=100, seed=None, step_size=17, zeroing=False):
+         device = x.device
+         loc_dict, loc = self.init_loc(x, step_size)
+         is_train = False
+         if self.transformer.training:
+             is_train = True
+             self.transformer.eval()
+         block_size = self.block_size
+         generator = None
+         if seed is not None:
+             generator = torch.Generator(device).manual_seed(seed)
+         is_visited = False
+         cache_counter = 0
+         for k in range(steps):
+             # check directionality
+             if k % step_size == 0:
+                 dir = x[:, -1].detach().item() - self.vqvae_vocab_size
+                 if dir == 0:
+                     is_visited = False
+                     loc = (0, 0)  # init loc
+                     loc_dict[loc] = torch.empty(1, 0).long().to(device)
+                 else:
+                     loc = self.get_loc(loc, dir)
+                     if loc in loc_dict:
+                         is_visited = True
+                         cache_counter = 0
+                     else:
+                         is_visited = False
+                         loc_dict[loc] = torch.empty(1, 0).long().to(device)
+ 
+             if x.size(1) < block_size:
+                 x_cond = x
+             else:
+                 remain = step_size - (x.size(1) % step_size)
+                 x_cond = x[:, -(block_size - remain):]  # crop context if needed
+                 if zeroing:
+                     x_cond = x_cond.clone()
+                     x_cond[:, 0] = self.vqvae_vocab_size
+ 
+             if not is_visited:
+                 logits = self.transformer(x_cond)
+                 logits = logits[:, -1, :] / temp
+ 
+                 if top_k is not None:
+                     logits = self.top_k_logits(logits, top_k)
+ 
+                 probs = F.softmax(logits, dim=-1)
+                 ix = torch.multinomial(probs, num_samples=1, generator=generator)
+                 loc_dict[loc] = torch.cat((loc_dict[loc], ix), dim=1)
+             else:
+                 if cache_counter == 15:  # reaching the end of the latent code
+                     is_visited = False
+                 ix = loc_dict[loc][:, [cache_counter]]
+                 cache_counter += 1
+ 
+             x = torch.cat((x, ix), dim=1)
+ 
+         if is_train:
+             self.transformer.train()
+ 
+         return x
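The shared token vocabulary that VQVAETransformer builds is worth spelling out: indices 0..1023 address the VQ-VAE codebook, while loc_to_indices/indices_to_loc shift the nine direction tokens into 1024..1032 (1024, direction 0, doubles as the start/origin token in app.py). A small round-trip sketch (editor's illustration, not in the commit):

    import torch

    vqvae_vocab_size = 1024          # param.json: latent_dim
    loc_vocab_size = 9               # param.json: loc_vocab_size

    dirs = torch.arange(loc_vocab_size)       # 0 = origin, 1..8 = compass moves
    token_ids = dirs + vqvae_vocab_size       # what loc_to_indices computes
    print(token_ids.tolist())                 # [1024, 1025, ..., 1032]
    locs = (token_ids - vqvae_vocab_size).clamp(0, loc_vocab_size - 1)  # indices_to_loc
    print(torch.equal(locs, dirs))            # True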
models/param.json ADDED
@@ -0,0 +1,22 @@
+ {
+     "loc_vocab_size": 9,
+     "block_size": 255,
+     "batch_size": 4,
+     "seq_num": 8,
+     "seq_length": 15,
+     "block_seq_length": 15,
+     "p": 10.0,
+     "q": 0.001,
+     "loc_dim": 1,
+     "isovist_latent_dim": 16,
+     "latent_dim": 1024,
+     "heads": 16,
+     "N": 24,
+     "pkeep": 1.0,
+     "vqvae_num_hiddens": 512,
+     "vqvae_num_residual_layers": 4,
+     "vqvae_num_residual_hiddens": 32,
+     "vqvae_latent_dim": 8,
+     "vqvae_commitment_cost": 0.25,
+     "vqvae_decay": 0.99
+ }
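Note how these hyperparameters fit together: each generation step is 1 location token plus isovist_latent_dim = 16 latent indices, so block_size = 255 holds exactly seq_length = 15 steps of 17 tokens (15 × 17 = 255), matching the step_size=17 default in gist1/vqvae_gpt.py.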
models/vqvaegpt_1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5b50c99dcdf274b6936bbb51903e26a401e3b5dd3bed194f1f9a7bf3b4fa8a05
+ size 1251118533
models/vqvaegpt_2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:241ecccdbda134d226e2e765ab7278c16bc001a51ac34ea939fa89ead1fe8398
+ size 1251118893
models/vqvaegpt_3.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2eba1d15b180caa8b5b7f69028cc17f1df69cfb58c2227eb3483739069302079
+ size 1251118533
requirements.txt ADDED
Binary file (234 Bytes).
 
utils/__pycache__/dataload.cpython-38.pyc ADDED
Binary file (9.38 kB).
 
utils/__pycache__/isoutil.cpython-38.pyc ADDED
Binary file (17.2 kB).
 
utils/__pycache__/misc.cpython-38.pyc ADDED
Binary file (2.62 kB).
 
utils/__pycache__/s3bucket.cpython-38.pyc ADDED
Binary file (2.03 kB).
 
utils/isoutil.py ADDED
@@ -0,0 +1,669 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from matplotlib.patches import Polygon
+ from matplotlib.collections import PatchCollection
+ from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
+ 
+ 
+ def pol2car(rho, phi, xi, yi):
+     x = rho * np.cos(phi) + xi
+     y = rho * np.sin(phi) + yi
+     return (x, y)
+ 
+ 
+ def car2pol(xi, yi):
+     rho = np.sqrt(xi**2 + yi**2)
+     phi = np.arctan2(yi, xi)
+     return (rho, phi)
+ 
+ 
+ def car2polnorm(xi, yi):
+     rho = np.sqrt(xi**2 + yi**2)
+     phi = np.arctan2(yi, xi)
+     phi %= 2 * np.pi
+     phi /= 2 * np.pi
+     return (rho, phi)
+ 
+ 
+ def plot_isovist(isovists, show_axis=False, s=0.1, figsize=(5, 5)):
+     plt.switch_backend('agg')
+     fig = plt.figure(figsize=figsize)
+     points = []
+     res = np.pi / 90
+     isovist = isovists
+     for j, rho in enumerate(isovist):
+         if rho < 1.0:
+             pt = pol2car(rho, j * res, 0, 0)
+             points.append(pt)
+     x = [i[0] for i in points]
+     y = [i[1] for i in points]
+     ax = fig.add_subplot(111)
+     ax.set_aspect('equal')
+     ax.set_xlim(-1, 1)
+     ax.set_ylim(-1, 1)
+     if not show_axis:
+         ax.axis('off')
+     ax.scatter(x, y, s, 'black')
+     return fig
+ 
+ 
+ def isovist_to_img(isovist, show_axis=False, s=0.1, figsize=(5, 5)):
+     points = []
+     xy = (0, 0)
+     res = np.pi / 90
+     isovist = isovist + 0.5
+     for j, rho in enumerate(isovist):
+         if rho <= 2.0:
+             pt = pol2car(rho, j * res, xy[0], xy[1])
+             points.append(pt)
+     x = [i[0] for i in points]
+     y = [i[1] for i in points]
+     fig = plt.figure(figsize=figsize)
+     canvas = FigureCanvas(fig)
+     ax = fig.add_subplot(111)
+     ax.set_aspect('equal')
+     ax.set_xlim(-1, 1)
+     ax.set_ylim(-1, 1)
+     if not show_axis:
+         ax.axis('off')
+     ax.scatter(x, y, s, 'black')
+ 
+     canvas.draw()
+     image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
+     return image
+ 
+ 
+ def isovist_to_img_a(isovist, show_axis=False, s=0.1, figsize=(5, 5)):
+     points = []
+     xy = (0, 0)
+     res = np.pi / 128
+     isovist = isovist + 0.5
+     for j, rho in enumerate(isovist):
+         if rho <= 2.0:
+             pt = pol2car(rho, j * res, xy[0], xy[1])
+             points.append(pt)
+     x = [i[0] for i in points]
+     y = [i[1] for i in points]
+     fig = plt.figure(figsize=figsize)
+     canvas = FigureCanvas(fig)
+     ax = fig.add_subplot(111)
+     ax.set_aspect('equal')
+     ax.set_xlim(-1, 1)
+     ax.set_ylim(-1, 1)
+     if not show_axis:
+         ax.axis('off')
+     ax.scatter(x, y, s, 'black')
+ 
+     canvas.draw()
+     image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
+     return image
+ 
+ 
+ def isovist_to_cartesian(isovist, x, y, scale):
+     points = []
+     xy = (x, y)
+     res = np.pi / 90
+     isovist = isovist * scale
+     for j, rho in enumerate(isovist):
+         if rho <= scale:
+             pt = pol2car(rho, j * res, xy[0], xy[1])
+         else:
+             pt = pol2car(scale, j * res, xy[0], xy[1])
+         points.append(pt)
+     points = np.stack(points)
+     return points
+ 
+ 
+ def isovist_to_cartesian_a(isovist, x, y, scale):
+     points = []
+     xy = (x, y)
+     res = np.pi / len(isovist) * 2  # 2*pi divided by the number of rays
+     isovist = isovist * scale
+     for j, rho in enumerate(isovist):
+         pt = pol2car(rho, j * res, xy[0], xy[1])
+         points.append(pt)
+     points = np.stack(points)
+     return points
+ 
+ 
+ def isovist_to_cartesian_b(isovist, x, y):
+     points = []
+     xy = (x, y)
+     res = np.pi * 2
+     for j, rho in isovist:
+         pt = pol2car(rho, j * res, xy[0], xy[1])
+         points.append(pt)
+     points = np.stack(points)
+     return points
+ 
+ 
+ def isovist_to_cartesian_segment(isovist, x, y, scale):
+     points = []
+     segment = []
+     xy = (x, y)
+     res = np.pi / 90
+     isovist = isovist * scale
+     p_rho = isovist[-1]
+     for j, rho in enumerate(isovist):
+         delta = abs(p_rho - rho)
+         if j == 0:
+             first_rho = rho
+         if rho < 0.98 * scale and delta < 0.05 * scale:
+             pt = pol2car(rho, j * res, xy[0], xy[1])
+             segment.append(pt)
+         else:
+             points.append(segment)
+             segment = []
+         p_rho = rho
+     # wrap the last segment around to the first if the boundary is continuous there
+     if first_rho < 1.0 * scale and abs(rho - first_rho) < 0.05 * scale:
+         if len(points) > 0:
+             segment.extend(points[0])
+             points[0] = segment
+         else:
+             points.append(segment)
+     else:
+         points.append(segment)
+     segments = []
+     for i in range(len(points)):
+         if len(points[i]) > 0:
+             segments.append(np.stack(points[i]))
+     return segments
+ 
+ 
+ def isovist_to_cartesian_segment_a(isovist, x, y, scale, max=0.98, min=0.1, d=0.1):
+     points = []
+     segment = []
+     xy = (x, y)
+     res = np.pi / len(isovist) * 2
+     isovist = isovist * scale
+     p_rho = isovist[-1]
+     for j, rho in enumerate(isovist):
+         delta = abs(p_rho - rho)
+         if j == 0:
+             first_rho = rho
+         if rho < max * scale and rho > min * scale and delta < d * scale:
+             pt = pol2car(rho, j * res, xy[0], xy[1])
+             segment.append(pt)
+         else:
+             points.append(segment)
+             segment = []
+         p_rho = rho
+     # wrap the last segment around to the first if the boundary is continuous there
+     if first_rho < max * scale and first_rho > min * scale and abs(rho - first_rho) < d * scale:
+         if len(points) > 0:
+             segment.extend(points[0])
+             points[0] = segment
+         else:
+             points.append(segment)
+     else:
+         points.append(segment)
+     segments = []
+     for i in range(len(points)):
+         if len(points[i]) > 0:
+             segments.append(np.stack(points[i]))
+     return segments
+ 
+ 
+ def isovist_to_cartesian_segment_b(isovist, x, y):
+     points = []
+     segment = []
+     xy = (x, y)
+     res = np.pi * 2
+     p_rho = isovist[-1, 1]
+     _i = 0
+     for j, rho in isovist:
+         delta = abs(p_rho - rho)
+         if _i == 0:
+             first_rho = rho
+         if rho < 0.98 and delta < 0.1:
+             pt = pol2car(rho, j * res, xy[0], xy[1])
+             segment.append(pt)
+         else:
+             points.append(segment)
+             segment = []
+         p_rho = rho
+         _i += 1
+     # wrap the last segment around to the first if the boundary is continuous there
+     if first_rho < 0.98 and abs(rho - first_rho) < 0.1:
+         if len(points) > 0:
+             segment.extend(points[0])
+             points[0] = segment
+         else:
+             points.append(segment)
+     else:
+         points.append(segment)
+     segments = []
+     for i in range(len(points)):
+         if len(points[i]) > 0:
+             segments.append(np.stack(points[i]))
+     return segments
+ 
+ 
+ # plot an isovist and return the numpy image
+ def plot_isovist_numpy(k, text=None, figsize=(8, 8)):
+     fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=300)
+ 
+     # plot isovist
+     xy = isovist_to_cartesian_a(k, 0, 0, 1.0)
+     polygon = Polygon(xy, closed=True)
+     p = PatchCollection([polygon])
+     p.set_facecolor('#dddddd')
+     p.set_edgecolor(None)
+     ax.add_collection(p)
+ 
+     # style
+     ax.set_aspect('equal')
+     lim = 1.2
+     ax.set_xlim(-lim, lim)
+     ax.set_ylim(-lim, lim)
+     ax.set_xticks([])
+     ax.set_yticks([])
+     ax.axis('off')
+     if text is not None:
+         ax.set_title(text, size=5)  # title
+     fig.tight_layout()
+ 
+     # for plotting with the torchvision utils
+     fig.canvas.draw()
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+     w, h = fig.canvas.get_width_height()
+     im = data.reshape((int(h), int(w), -1))
+     im = im.transpose((2, 0, 1))
+     plt.close()
+     return im
+ 
+ 
+ # plot an isovist and its assumed boundary, returning the numpy image
+ def plot_isovist_boundary_numpy(isovist, boundary, figsize=(8, 8)):
+     fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=300)
+ 
+     # plot isovist
+     xy = isovist_to_cartesian_a(isovist, 0, 0, 1.0)
+     polygon = Polygon(xy, closed=True)
+     p = PatchCollection([polygon])
+     p.set_facecolor('#eeeeee')
+     p.set_edgecolor(None)
+     ax.add_collection(p)
+ 
+     # plot assumed boundary
+     edge_patches = []
+     segments = isovist_to_cartesian_segment_a(boundary, 0, 0, 1.0)
+     for segment in segments:
+         polygon = Polygon(segment, closed=False)
+         edge_patches.append(polygon)
+     p = PatchCollection(edge_patches)
+     p.set_facecolor('none')
+     p.set_edgecolor('#000000')
+     p.set_linewidth(0.5)
+     ax.add_collection(p)
+ 
+     # style
+     ax.set_aspect('equal')
+     lim = 1.2
+     ax.set_xlim(-lim, lim)
+     ax.set_ylim(-lim, lim)
+     ax.set_xticks([])
+     ax.set_yticks([])
+     ax.axis('off')
+ 
+     # for plotting with the torchvision utils
+     fig.canvas.draw()
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+     w, h = fig.canvas.get_width_height()
+     im = data.reshape((int(h), int(w), -1))
+     im = im.transpose((2, 0, 1))
+     plt.close()
+     return im
+ 
+ 
+ # plot two isovists (fill and edge) and return the numpy image
+ def plot_isovist_double_numpy(isovist1, isovist2, figsize=(8, 8)):
+     fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=300)
+ 
+     # plot isovist1
+     xy = isovist_to_cartesian_a(isovist1, 0, 0, 1.0)
+     polygon = Polygon(xy, closed=True)
+     p = PatchCollection([polygon])
+     p.set_facecolor('#dddddd')
+     p.set_edgecolor(None)
+     ax.add_collection(p)
+ 
+     # plot isovist2 as boundary
+     xy = isovist_to_cartesian_a(isovist2, 0, 0, 1.0)
+     polygon = Polygon(xy, closed=True)
+     p = PatchCollection([polygon])
+     p.set_facecolor('none')
+     p.set_edgecolor('#000000')
+     p.set_linewidth(0.2)
+     ax.add_collection(p)
+ 
+     # style
+     ax.set_aspect('equal')
+     lim = 1.2
+     ax.set_xlim(-lim, lim)
+     ax.set_ylim(-lim, lim)
+     ax.set_xticks([])
+     ax.set_yticks([])
+     ax.axis('off')
+ 
+     # for plotting with the torchvision utils
+     fig.canvas.draw()
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+     w, h = fig.canvas.get_width_height()
+     im = data.reshape((int(h), int(w), -1))
+     im = im.transpose((2, 0, 1))
+     plt.close()
+     return im
+ 
+ 
+ # plot three isovists (two fills and an edge) and return the numpy image
+ def plot_isovist_triple_numpy(isovists, locs, figsize=(8, 8)):
+     isovist1, isovist2, isovist3 = isovists
+     loc1, loc2, loc3 = locs
+ 
+     fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=300)
+ 
+     # plot isovist1
+     xy = isovist_to_cartesian_a(isovist1, loc1[0], loc1[1], 1.0)
+     polygon = Polygon(xy, closed=True)
+     p = PatchCollection([polygon])
+     p.set_facecolor('#ffdddd')
+     p.set_edgecolor(None)
+     ax.add_collection(p)
+ 
+     # plot isovist2
+     xy = isovist_to_cartesian_a(isovist2, loc2[0], loc2[1], 1.0)
+     polygon = Polygon(xy, closed=True)
+     p = PatchCollection([polygon])
+     p.set_facecolor('#ddddff')
+     p.set_edgecolor(None)
+     ax.add_collection(p)
+ 
+     # plot isovist3 as boundary
+     xy = isovist_to_cartesian_a(isovist3, 0, 0, 1.0)
+     polygon = Polygon(xy, closed=True)
+     p = PatchCollection([polygon])
+     p.set_facecolor('none')
+     p.set_edgecolor('#000000')
+     p.set_linewidth(0.2)
+     ax.add_collection(p)
+ 
+     ax.scatter([x[0] for x in locs], [x[1] for x in locs], c='k', s=8, marker='+')
+ 
+     annotation = ['x1', 'x2', 'y']
+     for i, anno in enumerate(annotation):
+         ax.annotate(anno, (locs[i][0] + 0.1, locs[i][1]), size=8)
+ 
+     # style
+     ax.set_aspect('equal')
+     lim = 1.5
+     ax.set_xlim(-lim, lim)
+     ax.set_ylim(-lim, lim)
+     ax.set_xticks([])
+     ax.set_yticks([])
+     ax.axis('off')
+ 
+     # for plotting with the torchvision utils
+     fig.canvas.draw()
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+     w, h = fig.canvas.get_width_height()
+     im = data.reshape((int(h), int(w), -1))
+     im = im.transpose((2, 0, 1))
+     plt.close()
+     return im
+ 
+ 
+ # show an isovist sequence
+ def seq_show(locs, isovists, figsize=(8, 8)):
+     # walk through the sequence
+     b_segments = []
+     b_points = []
+     isovists_pts = []
+     res = np.pi / 128
+     p_loc = np.array([0, 0])
+     cartesian_locs = []
+     for loc, isovist in zip(locs, isovists):
+         rel_pos = np.asarray(pol2car(loc[0], loc[1] * 2 * np.pi, p_loc[0], p_loc[1]))
+         for j, rho in enumerate(isovist):
+             if rho < 0.98:
+                 pt = pol2car(rho, j * res, rel_pos[0], rel_pos[1])
+                 b_points.append(pt)
+         segments = isovist_to_cartesian_segment_a(isovist, rel_pos[0], rel_pos[1], 1.0)
+         b_segments.extend(segments)
+         isovists_pts.append(isovist_to_cartesian_a(isovist, rel_pos[0], rel_pos[1], 1.0))
+         cartesian_locs.append(rel_pos)
+         p_loc = rel_pos
+ 
+     fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=96)
+ 
+     # isovists
+     isovist_poly = []
+     for isovist_pts in isovists_pts:
+         isovist_poly.append(Polygon(isovist_pts, closed=True))
+     r = PatchCollection(isovist_poly)
+     r.set_facecolor('#000000')
+     r.set_edgecolor(None)
+     r.set_alpha(0.02)
+     ax.add_collection(r)
+ 
+     # isovist path
+     q = PatchCollection([Polygon(cartesian_locs, closed=False)])
+     q.set_facecolor('none')
+     q.set_edgecolor('#cccccc')
+     q.set_linewidth(1.0)
+     q.set_linestyle('dashed')
+     ax.add_collection(q)
+     ax.scatter([x[0] for x in cartesian_locs], [x[1] for x in cartesian_locs], s=6.0, c='red')
+ 
+     # boundaries
+     edge_patches = []
+     for segment in b_segments:
+         polygon = Polygon(segment, closed=False)
+         edge_patches.append(polygon)
+     p = PatchCollection(edge_patches)
+     p.set_facecolor('none')
+     p.set_edgecolor('#000000')
+     p.set_linewidth(1.0)
+     ax.add_collection(p)
+     ax.scatter([x[0] for x in b_points], [x[1] for x in b_points], s=0.05, c='k')
+ 
+     # style
+     ax.set_aspect('equal')
+     lim = 1.5
+     ax.set_xlim(-lim, lim)
+     ax.set_ylim(-lim, lim)
+     ax.set_xticks([])
+     ax.set_yticks([])
+     ax.axis('off')
+ 
+     return fig
+ 
+ 
+ # plot an isovist sequence and return the numpy image
+ def plot_isovist_sequence(locs, isovists, figsize=(8, 8)):
+     fig = seq_show(locs, isovists, figsize=figsize)
+ 
+     # for plotting with the torchvision utils
+     fig.canvas.draw()
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+     w, h = fig.canvas.get_width_height()
+     im = data.reshape((int(h), int(w), -1))
+     im = im.transpose((2, 0, 1))
+     plt.close()
+     return im
+ 
+ 
+ def index_to_loc_grid(idx, d):
+     if idx == 0:
+         return np.array((0., 0.), dtype=np.float32)
+     elif idx == 1:
+         return np.array((d, 0.), dtype=np.float32)
+     elif idx == 2:
+         return np.array((d, d), dtype=np.float32)
+     elif idx == 3:
+         return np.array((0., d), dtype=np.float32)
+     elif idx == 4:
+         return np.array((-d, d), dtype=np.float32)
+     elif idx == 5:
+         return np.array((-d, 0.), dtype=np.float32)
+     elif idx == 6:
+         return np.array((-d, -d), dtype=np.float32)
+     elif idx == 7:
+         return np.array((0., -d), dtype=np.float32)
+     elif idx == 8:
+         return np.array((d, -d), dtype=np.float32)
+     else:
+         raise NameError('Direction unknown')
+ 
+ 
+ # show an isovist sequence on the grid
+ def seq_show_grid(locs, isovists, d=0.2, figsize=(8, 8), center=False, lim=1.5, alpha=0.02,
+                   rad=0.9, b_width=1.0, calculate_lim=False):
+     # walk through the sequence
+     p_loc = np.array((0, 0))
+     b_segments = []
+     b_points = []
+     isovists_pts = []
+     res = np.pi / 128
+     cartesian_locs = []
+     for loc, isovist in zip(locs, isovists):
+         rel_pos = index_to_loc_grid(loc, d) + p_loc
+         for j, rho in enumerate(isovist):
+             if rho < rad:
+                 pt = pol2car(rho, j * res, rel_pos[0], rel_pos[1])
+                 b_points.append(pt)
+         segments = isovist_to_cartesian_segment_a(isovist, rel_pos[0], rel_pos[1], 1.0)
+         b_segments.extend(segments)
+         isovists_pts.append(isovist_to_cartesian_a(isovist, rel_pos[0], rel_pos[1], 1.0))
+         cartesian_locs.append(rel_pos)
+         p_loc = rel_pos
+ 
+     if len(b_points) > 0:
+         b_points = np.stack(b_points)
+     else:
+         b_points = []
+     isovists_pts = np.stack(isovists_pts)
+     cartesian_locs = np.stack(cartesian_locs)
+ 
+     # set graphic properties
+     isovist_path_width = 0.1
+     isovist_path_pt1 = 6.0
+     isovist_path_pt2 = 10.0
+     isovist_boundary_pt = 0.05
+ 
+     if center:
+         bbox = get_bbox(b_points)
+         center_pt = get_center_pts(bbox, np_array=True)
+         b_points = [pt - center_pt for pt in b_points]
+         isovists_pts = [pt - center_pt for pt in isovists_pts]
+         b_segments = [pt - center_pt for pt in b_segments]
+         cartesian_locs = [pt - center_pt for pt in cartesian_locs]
+ 
+         # resize the image extents if the walk grew beyond the default limit
+         if calculate_lim:
+             if bbox is not None:
+                 max_extent = np.max(np.abs(bbox))
+             else:
+                 max_extent = 2.0
+             if max_extent > 2.0:
+                 lim = ((max_extent // 0.5) + 1) * 0.5
+                 isovist_path_width *= 2.0 / lim
+                 isovist_path_pt1 *= 2.0 / lim
+                 isovist_path_pt2 *= 2.0 / lim
+                 isovist_boundary_pt *= 2.0 / lim
+ 
+     fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=96)
+ 
+     # isovists
+     isovist_poly = []
+     for isovist_pts in isovists_pts:
+         isovist_poly.append(Polygon(isovist_pts, closed=True))
+     r = PatchCollection(isovist_poly)
+     r.set_facecolor('#00aabb')
+     r.set_edgecolor(None)
+     r.set_alpha(alpha)
+     ax.add_collection(r)
+ 
+     # isovist path
+     q = PatchCollection([Polygon(cartesian_locs, closed=False)])
+     q.set_facecolor('none')
+     q.set_edgecolor('red')
+     q.set_linewidth(isovist_path_width)
+     ax.add_collection(q)
+ 
+     # start point
+     ax.scatter([x[0] for x in cartesian_locs[:1]], [x[1] for x in cartesian_locs[:1]],
+                s=isovist_path_pt1, c='k', marker='s')
+ 
+     # sequence
+     ax.scatter([x[0] for x in cartesian_locs[1:-1]], [x[1] for x in cartesian_locs[1:-1]],
+                s=isovist_path_pt1, c='red')
+ 
+     # end point
+     ax.scatter([x[0] for x in cartesian_locs[-1:]], [x[1] for x in cartesian_locs[-1:]],
+                s=isovist_path_pt2, c='k', marker='x')
+ 
+     # boundaries (drawn as points; the edge collection is built but intentionally not added)
+     edge_patches = []
+     for segment in b_segments:
+         if len(segment) > 5:
+             polygon = Polygon(segment, closed=False)
+             edge_patches.append(polygon)
+     p = PatchCollection(edge_patches)
+     p.set_facecolor('none')
+     p.set_edgecolor('#000000')
+     p.set_linewidth(b_width)
+     ax.scatter([x[0] for x in b_points], [x[1] for x in b_points], s=isovist_boundary_pt, c='#000000')
+     # ax.add_collection(p)
+ 
+     # style
+     ax.set_aspect('equal')
+     ax.set_xlim(-lim, lim)
+     ax.set_ylim(-lim, lim)
+     ax.set_xticks([])
+     ax.set_yticks([])
+     ax.axis('off')
+ 
+     return fig
+ 
+ 
+ # plot an isovist sequence on the grid and return the numpy image
+ def plot_isovist_sequence_grid(locs, isovists, figsize=(8, 8), center=False, lim=1.5, alpha=0.02,
+                                rad=0.9, b_width=1.0, calculate_lim=False):
+     fig = seq_show_grid(locs, isovists, figsize=figsize, center=center, lim=lim, alpha=alpha,
+                         rad=rad, b_width=b_width, calculate_lim=calculate_lim)
+     # for plotting with the torchvision utils
+     fig.canvas.draw()
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+     w, h = fig.canvas.get_width_height()
+     im = data.reshape((int(h), int(w), -1))
+     im = im.transpose((2, 0, 1))
+     plt.close()
+     return im
+ 
+ 
+ def get_bbox(pts):
+     if len(pts) > 0:
+         if isinstance(pts, list):
+             pts = np.stack(pts)
+         bbox = np.min(pts[:, 0]), np.max(pts[:, 0]), np.min(pts[:, 1]), np.max(pts[:, 1])
+         return bbox
+     else:
+         return None
+ 
+ 
+ def get_center_pts(bbox, np_array=False):
+     if bbox is not None:
+         center = 0.5 * (bbox[0] + bbox[1]), 0.5 * (bbox[2] + bbox[3])
+         if np_array:
+             center = np.asarray(center)
+     else:
+         center = np.asarray([0, 0])
+     return center
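A quick round-trip check for the polar helpers that everything in this module builds on (editor's sketch, not in the commit):

    import numpy as np
    from utils.isoutil import pol2car, car2pol

    x, y = pol2car(2.0, np.pi / 4, 0.0, 0.0)          # rho = 2 at 45 degrees from the origin
    rho, phi = car2pol(x, y)
    print(np.allclose([rho, phi], [2.0, np.pi / 4]))  # True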
utils/misc.py ADDED
@@ -0,0 +1,67 @@
+ import json
+ from os.path import join
+ import numpy as np
+ from PIL import Image
+ from utils.isoutil import *
+ import torch
+ import torchvision
+ import sys
+ 
+ 
+ class MeanTracker(object):
+     def __init__(self, name):
+         self.values = []
+         self.name = name
+ 
+     def add(self, val):
+         self.values.append(float(val))
+ 
+     def mean(self):
+         return np.mean(self.values)
+ 
+     def flush(self):
+         mean = self.mean()
+         self.values = []
+         return self.name, mean
+ 
+ 
+ def save_params(config, training_path):
+     save_dict_path = join(training_path, 'param.json')
+     with open(save_dict_path, 'w') as outfile:
+         json.dump(config,
+                   outfile,
+                   sort_keys=False,
+                   indent=4,
+                   separators=(',', ': '))
+ 
+ 
+ def load_params(config_file):
+     with open(config_file, 'r') as f:
+         data = json.load(f)
+     return data
+ 
+ 
+ def save_images(isovists, iter_num, title, sample_folder):
+     figs = []
+     for i, x_ in enumerate(isovists):
+         x_ = np.squeeze(x_)
+         figs.append(plot_isovist_numpy(x_, figsize=(1, 1)))
+     # stack first: building a tensor from a list of arrays is slow and warns
+     figs = torch.tensor(np.stack(figs), dtype=torch.float)
+     nrow = int(np.sqrt(isovists.shape[0]))
+     im = torchvision.utils.make_grid(figs, normalize=True, range=(0, 255), nrow=nrow)
+     im = Image.fromarray(np.uint8(np.transpose(im.numpy(), (1, 2, 0)) * 255))
+     im.save(join(sample_folder, f'{title}_{iter_num:06}.jpg'))
+ 
+ 
+ def imshow(img):
+     npimg = img.numpy()
+     plt.figure(figsize=(30, 30))
+     plt.imshow(np.transpose(npimg, (1, 2, 0)))
+     plt.axis('off')
+     plt.show()
+ 
+ 
+ def write(text):
+     sys.stdout.write('\n' + text)
+     if hasattr(sys.stdout, 'flush'):
+         sys.stdout.flush()
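save_params and load_params are simple JSON round-trip helpers; a usage sketch (editor's illustration, not in the commit — note that save_params always writes param.json inside the given folder):

    from utils.misc import save_params, load_params

    cfg = {"latent_dim": 1024, "heads": 16}
    save_params(cfg, ".")                         # writes ./param.json
    print(load_params("./param.json") == cfg)     # True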