aharley committed
Commit 77a88de · 1 Parent(s): f7f5275

added basics

Files changed (10)
  1. nets/alltracker.py +588 -0
  2. nets/blocks.py +1304 -0
  3. utils/basic.py +144 -0
  4. utils/data.py +96 -0
  5. utils/improc.py +1103 -0
  6. utils/loss.py +220 -0
  7. utils/misc.py +100 -0
  8. utils/py.py +755 -0
  9. utils/samp.py +213 -0
  10. utils/saveload.py +65 -0
nets/alltracker.py ADDED
@@ -0,0 +1,588 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import utils.misc
5
+ import numpy as np
6
+
7
+ from nets.blocks import CNBlockConfig, ConvNeXt, conv1x1, RelUpdateBlock, InputPadder, CorrBlock, BasicEncoder
8
+
9
+ class Net(nn.Module):
10
+ def __init__(
11
+ self,
12
+ seqlen,
13
+ use_attn=True,
14
+ use_mixer=False,
15
+ use_conv=False,
16
+ use_convb=False,
17
+ use_basicencoder=False,
18
+ use_sinmotion=False,
19
+ use_relmotion=False,
20
+ use_sinrelmotion=False,
21
+ use_feats8=False,
22
+ no_time=False,
23
+ no_space=False,
24
+ no_split=False,
25
+ no_ctx=False,
26
+ full_split=False,
27
+ corr_levels=5,
28
+ corr_radius=4,
29
+ num_blocks=3,
30
+ dim=128,
31
+ hdim=128,
32
+ init_weights=True,
33
+ ):
34
+ super(Net, self).__init__()
35
+
36
+ self.dim = dim
37
+ self.hdim = hdim
38
+
39
+ self.no_time = no_time
40
+ self.no_space = no_space
41
+ self.seqlen = seqlen
42
+ self.corr_levels = corr_levels
43
+ self.corr_radius = corr_radius
44
+ self.corr_channel = self.corr_levels * (self.corr_radius * 2 + 1) ** 2
45
+ self.num_blocks = num_blocks
46
+
47
+ self.use_feats8 = use_feats8
48
+ self.use_basicencoder = use_basicencoder
49
+ self.use_sinmotion = use_sinmotion
50
+ self.use_relmotion = use_relmotion
51
+ self.use_sinrelmotion = use_sinrelmotion
52
+ self.no_split = no_split
53
+ self.no_ctx = no_ctx
54
+ self.full_split = full_split
55
+
56
+ if use_basicencoder:
57
+ if self.full_split:
58
+ self.fnet = BasicEncoder(input_dim=3, output_dim=self.dim, stride=8)
59
+ self.cnet = BasicEncoder(input_dim=3, output_dim=self.dim, stride=8)
60
+ else:
61
+ if self.no_split:
62
+ self.fnet = BasicEncoder(input_dim=3, output_dim=self.dim, stride=8)
63
+ else:
64
+ self.fnet = BasicEncoder(input_dim=3, output_dim=self.dim*2, stride=8)
65
+ else:
66
+ block_setting = [
67
+ CNBlockConfig(96, 192, 3, True), # 4x
68
+ CNBlockConfig(192, 384, 3, False), # 8x
69
+ CNBlockConfig(384, None, 9, False), # 8x
70
+ ]
71
+ self.cnn = ConvNeXt(block_setting, stochastic_depth_prob=0.0, init_weights=init_weights)
72
+ if self.no_split:
73
+ self.dot_conv = conv1x1(384, dim)
74
+ else:
75
+ self.dot_conv = conv1x1(384, dim*2)
76
+
77
+ self.upsample_weight = nn.Sequential(
78
+ # convex combination of 3x3 patches
79
+ nn.Conv2d(dim, dim * 2, 3, padding=1),
80
+ nn.ReLU(inplace=True),
81
+ nn.Conv2d(dim * 2, 64 * 9, 1, padding=0)
82
+ )
83
+ self.flow_head = nn.Sequential(
84
+ nn.Conv2d(dim, 2*dim, kernel_size=3, padding=1),
85
+ nn.ReLU(inplace=True),
86
+ nn.Conv2d(2*dim, 2, kernel_size=3, padding=1)
87
+ )
88
+ self.visconf_head = nn.Sequential(
89
+ nn.Conv2d(dim, 2*dim, kernel_size=3, padding=1),
90
+ nn.ReLU(inplace=True),
91
+ nn.Conv2d(2*dim, 2, kernel_size=3, padding=1)
92
+ )
93
+
94
+ if self.use_sinrelmotion:
95
+ self.pdim = 84 # 32*2
96
+ elif self.use_relmotion:
97
+ self.pdim = 4
98
+ elif self.use_sinmotion:
99
+ self.pdim = 42
100
+ else:
101
+ self.pdim = 2
102
+
103
+ self.update_block = RelUpdateBlock(self.corr_channel, self.num_blocks, cdim=dim, hdim=hdim, pdim=self.pdim,
104
+ use_attn=use_attn, use_mixer=use_mixer, use_conv=use_conv, use_convb=use_convb,
105
+ use_layer_scale=True, no_time=no_time, no_space=no_space,
106
+ no_ctx=no_ctx)
107
+
108
+ time_line = torch.linspace(0, seqlen-1, seqlen).reshape(1, seqlen, 1)
109
+ self.register_buffer("time_emb", utils.misc.get_1d_sincos_pos_embed_from_grid(self.dim, time_line[0])) # 1,S,C
110
+
111
+
112
+ def fetch_time_embed(self, t, dtype, is_training=False):
113
+ S = self.time_emb.shape[1]
114
+ if t == S:
115
+ return self.time_emb.to(dtype)
116
+ elif t==1:
117
+ if is_training:
118
+ ind = np.random.choice(S)
119
+ return self.time_emb[:,ind:ind+1].to(dtype)
120
+ else:
121
+ return self.time_emb[:,1:2].to(dtype)
122
+ else:
123
+ time_emb = self.time_emb.float()
124
+ time_emb = F.interpolate(time_emb.permute(0, 2, 1), size=t, mode="linear").permute(0, 2, 1)
125
+ return time_emb.to(dtype)
126
+
127
+ def coords_grid(self, batch, ht, wd, device, dtype):
128
+ coords = torch.meshgrid(torch.arange(ht, device=device, dtype=dtype), torch.arange(wd, device=device, dtype=dtype), indexing='ij')
129
+ coords = torch.stack(coords[::-1], dim=0)
130
+ return coords[None].repeat(batch, 1, 1, 1)
131
+
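For reference, a minimal standalone sketch (illustrative names, not part of the repo) of the layout coords_grid produces: channel 0 holds x (column) indices and channel 1 holds y (row) indices, which is what the coords[::-1] reversal above achieves.

import torch

def coords_grid_demo(batch, ht, wd):
    # same math as Net.coords_grid, with default dtype, for illustration only
    ys, xs = torch.meshgrid(torch.arange(ht), torch.arange(wd), indexing='ij')
    grid = torch.stack([xs, ys], dim=0)  # channel 0 = x, channel 1 = y
    return grid[None].repeat(batch, 1, 1, 1)

g = coords_grid_demo(1, 2, 3)
print(g.shape)   # torch.Size([1, 2, 2, 3])
print(g[0, 0])   # x: [[0, 1, 2], [0, 1, 2]]
print(g[0, 1])   # y: [[0, 0, 0], [1, 1, 1]]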
132
+ def initialize_flow(self, img):
133
+ """ Flow is represented as difference between two coordinate grids flow = coords2 - coords1"""
134
+ N, C, H, W = img.shape
135
+ coords1 = self.coords_grid(N, H//8, W//8, device=img.device)
136
+ coords2 = self.coords_grid(N, H//8, W//8, device=img.device)
137
+ return coords1, coords2
138
+
139
+ def upsample_data(self, flow, mask):
140
+ """ Upsample [H/8, W/8, C] -> [H, W, C] using convex combination """
141
+ N, C, H, W = flow.shape
142
+ mask = mask.view(N, 1, 9, 8, 8, H, W)
143
+ mask = torch.softmax(mask, dim=2)
144
+
145
+ up_flow = F.unfold(8 * flow, [3,3], padding=1)
146
+ up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
147
+
148
+ up_flow = torch.sum(mask * up_flow, dim=2)
149
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
150
+
151
+ return up_flow.reshape(N, 2, 8*H, 8*W).to(flow.dtype)
152
+
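A shape-level sketch of the convex upsampling above, on dummy tensors: the mask supplies 9 softmax weights for each pixel of every 8x8 output block, and the flow is multiplied by 8 because displacements scale with resolution. Variable names here are illustrative only.

import torch
import torch.nn.functional as F

N, H, W = 1, 4, 5
flow = torch.randn(N, 2, H, W)            # coarse flow at 1/8 resolution
mask = torch.randn(N, 64 * 9, H, W)       # as produced by upsample_weight

mask = mask.view(N, 1, 9, 8, 8, H, W)
mask = torch.softmax(mask, dim=2)         # convex weights over each 3x3 neighborhood

up_flow = F.unfold(8 * flow, [3, 3], padding=1)   # N, 2*9, H*W
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
up_flow = torch.sum(mask * up_flow, dim=2)        # N, 2, 8, 8, H, W
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)       # N, 2, H, 8, W, 8
print(up_flow.reshape(N, 2, 8 * H, 8 * W).shape)  # torch.Size([1, 2, 32, 40])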
153
+ def get_T_padded_images(self, images, T, S, is_training, stride=None, pad=True):
154
+ B,T,C,H,W = images.shape
155
+ indices = None
156
+ if T > 2:
157
+ step = S // 2 if stride is None else stride
158
+ indices = []
159
+ start = 0
160
+ while start + S < T:
161
+ indices.append(start)
162
+ start += step
163
+ indices.append(start)
164
+ Tpad = indices[-1]+S-T
165
+ if pad:
166
+ if is_training:
167
+ assert Tpad == 0
168
+ else:
169
+ images = images.reshape(B,1,T,C*H*W)
170
+ if Tpad > 0:
171
+ padding_tensor = images[:,:,-1:,:].expand(B,1,Tpad,C*H*W)
172
+ images = torch.cat([images, padding_tensor], dim=2)
173
+ images = images.reshape(B,T+Tpad,C,H,W)
174
+ T = T+Tpad
175
+ else:
176
+ assert T == 2
177
+ return images, T, indices
178
+
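The windowing above is easiest to see with concrete numbers; the sketch below mirrors the same index loop (hypothetical values for T, S, and step).

def window_starts(T, S, step):
    # mirrors the loop in get_T_padded_images
    indices, start = [], 0
    while start + S < T:
        indices.append(start)
        start += step
    indices.append(start)
    return indices

print(window_starts(T=10, S=4, step=2))  # [0, 2, 4, 6] -> last window covers frames 6..9, so Tpad = 0
print(window_starts(T=9, S=4, step=2))   # [0, 2, 4, 6] -> last window needs Tpad = 6 + 4 - 9 = 1 repeated frame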
179
+ def get_fmaps(self, images_, B, T, sw, is_training):
180
+ _, _, H_pad, W_pad = images_.shape # revised HW
181
+
182
+ C, H8, W8 = self.dim*2, H_pad//8, W_pad//8
183
+ if self.no_split:
184
+ C = self.dim
185
+
186
+ fmaps_chunk_size = 32
187
+ if (not is_training) and (T > fmaps_chunk_size):
188
+ images = images_.reshape(B,T,3,H_pad,W_pad)
189
+ fmaps = []
190
+ for t in range(0, T, fmaps_chunk_size):
191
+ images_chunk = images[:, t : t + fmaps_chunk_size]
192
+ images_chunk = images_chunk.cuda()
193
+ if self.use_basicencoder:
194
+ if self.full_split:
195
+ fmaps_chunk1 = self.fnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
196
+ fmaps_chunk2 = self.cnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
197
+ fmaps_chunk = torch.cat([fmaps_chunk1, fmaps_chunk2], axis=1)
198
+ else:
199
+ fmaps_chunk = self.fnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
200
+ else:
201
+ fmaps_chunk = self.cnn(images_chunk.reshape(-1, 3, H_pad, W_pad))
202
+ if t==0 and sw is not None and sw.save_this:
203
+ sw.summ_feat('1_model/fmap_raw', fmaps_chunk[0:1])
204
+ fmaps_chunk = self.dot_conv(fmaps_chunk) # B*T,C,H8,W8
205
+ T_chunk = images_chunk.shape[1]
206
+ fmaps.append(fmaps_chunk.reshape(B, -1, C, H8, W8))
207
+ fmaps_ = torch.cat(fmaps, dim=1).reshape(-1, C, H8, W8)
208
+ else:
209
+ if not is_training:
210
+ # sometimes we need to move things to cuda here
211
+ images_ = images_.cuda()
212
+ if self.use_basicencoder:
213
+ if self.full_split:
214
+ fmaps1_ = self.fnet(images_)
215
+ fmaps2_ = self.cnet(images_)
216
+ fmaps_ = torch.cat([fmaps1_, fmaps2_], axis=1)
217
+ else:
218
+ fmaps_ = self.fnet(images_)
219
+ else:
220
+ fmaps_ = self.cnn(images_)
221
+ if sw is not None and sw.save_this:
222
+ sw.summ_feat('1_model/fmap_raw', fmaps_[0:1])
223
+ fmaps_ = self.dot_conv(fmaps_) # B*T,C,H8,W8
224
+ return fmaps_
225
+
226
+ def forward(self, images, iters=4, sw=None, is_training=False, stride=None):
227
+ B,T,C,H,W = images.shape
228
+ S = self.seqlen
229
+ device = images.device
230
+ dtype = images.dtype
231
+
232
+ print('images', images.shape)
233
+
234
+ # images are in [0,255]
235
+ mean = torch.as_tensor([0.485, 0.456, 0.406], device=device).reshape(1,1,3,1,1).to(images.dtype)
236
+ std = torch.as_tensor([0.229, 0.224, 0.225], device=device).reshape(1,1,3,1,1).to(images.dtype)
237
+ images = images / 255.0
238
+ images = (images - mean)/std
239
+ print("a0 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
240
+
241
+ T_bak = T
242
+ if stride is not None:
243
+ pad = False
244
+ else:
245
+ pad = True
246
+ images, T, indices = self.get_T_padded_images(images, T, S, is_training, stride=stride, pad=pad)
247
+
248
+ images = images.contiguous()
249
+ images_ = images.reshape(B*T,3,H,W)
250
+ padder = InputPadder(images_.shape)
251
+ images_ = padder.pad(images_)[0]
252
+
253
+ print("a1 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
254
+
255
+ _, _, H_pad, W_pad = images_.shape # revised HW
256
+ C, H8, W8 = self.dim*2, H_pad//8, W_pad//8
257
+ C2 = C//2
258
+ if self.no_split:
259
+ C = self.dim
260
+ C2 = C
261
+
262
+ fmaps = self.get_fmaps(images_, B, T, sw, is_training).reshape(B,T,C,H8,W8)
263
+ device = fmaps.device
264
+ print("a2 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
265
+
266
+ fmap_anchor = fmaps[:,0]
267
+
268
+ if T<=2 or is_training:
269
+ # note: collecting preds can get expensive on a long video
270
+ all_flow_preds = []
271
+ all_visconf_preds = []
272
+ else:
273
+ all_flow_preds = None
274
+ all_visconf_preds = None
275
+
276
+ if T > 2: # multiframe tracking
277
+
278
+ # we will store our final outputs in these tensors
279
+ full_flows = torch.zeros((B,T,2,H,W), dtype=dtype, device=device)
280
+ full_visconfs = torch.zeros((B,T,2,H,W), dtype=dtype, device=device)
281
+ # 1/8 resolution
282
+ full_flows8 = torch.zeros((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
283
+ full_visconfs8 = torch.zeros((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
284
+
285
+ if self.use_feats8:
286
+ full_feats8 = torch.zeros((B,T,C2,H_pad//8,W_pad//8), dtype=dtype, device=device)
287
+ visits = np.zeros((T))
288
+ print("a3 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
289
+
290
+ for ii, ind in enumerate(indices):
291
+ ara = np.arange(ind,ind+S)
292
+ print('ara', ara)
293
+ if ii < len(indices)-1:
294
+ next_ind = indices[ii+1]
295
+ next_ara = np.arange(next_ind,next_ind+S)
296
+
297
+ # print("torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024), 'ara', ara)
298
+ fmaps2 = fmaps[:,ara]
299
+ flows8 = full_flows8[:,ara].reshape(B*(S),2,H_pad//8,W_pad//8).detach()
300
+ visconfs8 = full_visconfs8[:,ara].reshape(B*(S),2,H_pad//8,W_pad//8).detach()
301
+
302
+ if self.use_feats8:
303
+ if ind==0:
304
+ feats8 = None
305
+ else:
306
+ feats8 = full_feats8[:,ara].reshape(B*(S),C2,H_pad//8,W_pad//8).detach()
307
+ else:
308
+ feats8 = None
309
+ print("a4 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
310
+
311
+ flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
312
+ fmap_anchor, fmaps2, visconfs8, iters=iters, flowfeat=feats8, flows8=flows8,
313
+ is_training=is_training)
314
+ print("a5 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
315
+
316
+ unpad_flow_predictions = []
317
+ unpad_visconf_predictions = []
318
+ for i in range(len(flow_predictions)):
319
+ flow_predictions[i] = padder.unpad(flow_predictions[i])
320
+ unpad_flow_predictions.append(flow_predictions[i].reshape(B,S,2,H,W))
321
+ visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
322
+ unpad_visconf_predictions.append(visconf_predictions[i].reshape(B,S,2,H,W))
323
+ print("a6 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
324
+
325
+ full_flows[:,ara] = unpad_flow_predictions[-1].reshape(B,S,2,H,W)
326
+ full_flows8[:,ara] = flows8.reshape(B,S,2,H_pad//8,W_pad//8)
327
+ full_visconfs[:,ara] = unpad_visconf_predictions[-1].reshape(B,S,2,H,W)
328
+ full_visconfs8[:,ara] = visconfs8.reshape(B,S,2,H_pad//8,W_pad//8)
329
+ if self.use_feats8:
330
+ full_feats8[:,ara] = feats8.reshape(B,S,C2,H_pad//8,W_pad//8)
331
+ visits[ara] += 1
332
+ print("a7 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
333
+
334
+ if is_training:
335
+ all_flow_preds.append(unpad_flow_predictions)
336
+ all_visconf_preds.append(unpad_visconf_predictions)
337
+ else:
338
+ del unpad_flow_predictions
339
+ del unpad_visconf_predictions
340
+
341
+ # for the next iter, replace empty data with nearest available preds
342
+ invalid_idx = np.where(visits==0)[0]
343
+ valid_idx = np.where(visits>0)[0]
344
+ for idx in invalid_idx:
345
+ nearest = valid_idx[np.argmin(np.abs(valid_idx - idx))]
346
+ # print('replacing %d with %d' % (idx, nearest))
347
+ full_flows8[:,idx] = full_flows8[:,nearest]
348
+ full_visconfs8[:,idx] = full_visconfs8[:,nearest]
349
+ if self.use_feats8:
350
+ full_feats8[:,idx] = full_feats8[:,nearest]
351
+ print("a8 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
352
+ else: # flow
353
+
354
+ flows8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
355
+ visconfs8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
356
+
357
+ flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
358
+ fmap_anchor, fmaps[:,1:2], visconfs8, iters=iters, flowfeat=None, flows8=flows8,
359
+ is_training=is_training)
360
+ unpad_flow_predictions = []
361
+ unpad_visconf_predictions = []
362
+ for i in range(len(flow_predictions)):
363
+ flow_predictions[i] = padder.unpad(flow_predictions[i])
364
+ all_flow_preds.append(flow_predictions[i].reshape(B,2,H,W))
365
+ visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
366
+ all_visconf_preds.append(visconf_predictions[i].reshape(B,2,H,W))
367
+ full_flows = all_flow_preds[-1].reshape(B,2,H,W)
368
+ full_visconfs = all_visconf_preds[-1].reshape(B,2,H,W)
369
+
370
+ if (not is_training) and (T > 2):
371
+ full_flows = full_flows[:,:T_bak]
372
+ full_visconfs = full_visconfs[:,:T_bak]
373
+ print("a9 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
374
+
375
+ return full_flows, full_visconfs, all_flow_preds, all_visconf_preds
376
+
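A hedged usage sketch of the forward pass (assumes a CUDA device, since the memory printouts above call torch.cuda.memory_allocated; the constructor call uses only seqlen and relies on the defaults above, which may not match the released configuration).

import torch
from nets.alltracker import Net

model = Net(seqlen=16).cuda().eval()
video = torch.randint(0, 256, (1, 16, 3, 256, 256)).float().cuda()  # values in [0, 255], as expected above
with torch.no_grad():
    flows, visconfs, _, _ = model(video, iters=4, is_training=False)
print(flows.shape, visconfs.shape)  # (1, 16, 2, 256, 256) each: flow and vis/conf from frame 0 to every frame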
377
+ def forward_sliding(self, images, iters=4, sw=None, is_training=False, window_len=None, stride=None):
378
+ B,T,C,H,W = images.shape
379
+ S = self.seqlen if window_len is None else window_len
380
+ device = images.device
381
+ dtype = images.dtype
382
+ stride = S // 2 if stride is None else stride
383
+
384
+ T_bak = T
385
+ images, T, indices = self.get_T_padded_images(images, T, S, is_training, stride)
386
+ assert stride <= S // 2
387
+
388
+ images = images.contiguous()
389
+ images_ = images.reshape(B*T,3,H,W)
390
+ padder = InputPadder(images_.shape)
391
+ images_ = padder.pad(images_)[0]
392
+
393
+ _, _, H_pad, W_pad = images_.shape # revised HW
394
+ C, H8, W8 = self.dim*2, H_pad//8, W_pad//8
395
+ C2 = C//2
396
+ if self.no_split:
397
+ C = self.dim
398
+ C2 = C
399
+
400
+ all_flow_preds = None
401
+ all_visconf_preds = None
402
+
403
+ if T<=2:
404
+ # note: collecting preds can get expensive on a long video
405
+ all_flow_preds = []
406
+ all_visconf_preds = []
407
+
408
+ fmaps = self.get_fmaps(images_, B, T, sw, is_training).reshape(B,T,C,H8,W8)
409
+ device = fmaps.device
410
+
411
+ flows8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
412
+ visconfs8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
413
+
414
+ fmap_anchor = fmaps[:,0]
415
+
416
+ flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
417
+ fmap_anchor, fmaps[:,1:2], visconfs8, iters=iters, flowfeat=None, flows8=flows8,
418
+ is_training=is_training)
419
+ unpad_flow_predictions = []
420
+ unpad_visconf_predictions = []
421
+ for i in range(len(flow_predictions)):
422
+ flow_predictions[i] = padder.unpad(flow_predictions[i])
423
+ all_flow_preds.append(flow_predictions[i].reshape(B,2,H,W))
424
+ visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
425
+ all_visconf_preds.append(visconf_predictions[i].reshape(B,2,H,W))
426
+ full_flows = all_flow_preds[-1].reshape(B,2,H,W).detach().cpu()
427
+ full_visconfs = all_visconf_preds[-1].reshape(B,2,H,W).detach().cpu()
428
+
429
+ return full_flows, full_visconfs, all_flow_preds, all_visconf_preds
430
+
431
+ assert T > 2 # multiframe tracking
432
+
433
+ if is_training:
434
+ all_flow_preds = []
435
+ all_visconf_preds = []
436
+
437
+ # we will store our final outputs in these cpu tensors
438
+ full_flows = torch.zeros((B,T,2,H,W), dtype=dtype, device='cpu')
439
+ full_visconfs = torch.zeros((B,T,2,H,W), dtype=dtype, device='cpu')
440
+
441
+ images_ = images_.reshape(B,T,3,H_pad,W_pad)
442
+ fmap_anchor = self.get_fmaps(images_[:,:1].reshape(-1,3,H_pad,W_pad), B, 1, sw, is_training).reshape(B,C,H8,W8)
443
+ device = fmap_anchor.device
444
+ full_visited = torch.zeros((T,), dtype=torch.bool, device=device)
445
+
446
+ for ii, ind in enumerate(indices):
447
+ ara = np.arange(ind,ind+S)
448
+ if ii == 0:
449
+ flows8 = torch.zeros((B,S,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
450
+ visconfs8 = torch.zeros((B,S,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
451
+ fmaps2 = self.get_fmaps(images_[:,ara].reshape(-1,3,H_pad,W_pad), B, S, sw, is_training).reshape(B,S,C,H8,W8)
452
+ else:
453
+ flows8 = torch.cat([flows8[:,stride:stride+S//2], flows8[:,stride+S//2-1:stride+S//2].repeat(1,S//2,1,1,1)], dim=1)
454
+ visconfs8 = torch.cat([visconfs8[:,stride:stride+S//2], visconfs8[:,stride+S//2-1:stride+S//2].repeat(1,S//2,1,1,1)], dim=1)
455
+ fmaps2 = torch.cat([fmaps2[:,stride:stride+S//2],
456
+ self.get_fmaps(images_[:,np.arange(ind+S//2,ind+S)].reshape(-1,3,H_pad,W_pad), B, S//2, sw, is_training).reshape(B,S//2,C,H8,W8)], dim=1)
457
+
458
+ flows8 = flows8.reshape(B*S,2,H_pad//8,W_pad//8).detach()
459
+ visconfs8 = visconfs8.reshape(B*S,2,H_pad//8,W_pad//8).detach()
460
+
461
+ flow_predictions, visconf_predictions, flows8, visconfs8, _ = self.forward_window(
462
+ fmap_anchor, fmaps2, visconfs8, iters=iters, flowfeat=None, flows8=flows8,
463
+ is_training=is_training)
464
+
465
+ unpad_flow_predictions = []
466
+ unpad_visconf_predictions = []
467
+ for i in range(len(flow_predictions)):
468
+ flow_predictions[i] = padder.unpad(flow_predictions[i])
469
+ unpad_flow_predictions.append(flow_predictions[i].reshape(B,S,2,H,W))
470
+ visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
471
+ unpad_visconf_predictions.append(visconf_predictions[i].reshape(B,S,2,H,W))
472
+
473
+ current_visiting = torch.zeros((T,), dtype=torch.bool, device=device)
474
+ current_visiting[ara] = True
475
+
476
+ to_fill = current_visiting & (~full_visited)
477
+ to_fill_sum = to_fill.sum().item()
478
+ full_flows[:,to_fill] = unpad_flow_predictions[-1].reshape(B,S,2,H,W)[:,-to_fill_sum:].detach().cpu()
479
+ full_visconfs[:,to_fill] = unpad_visconf_predictions[-1].reshape(B,S,2,H,W)[:,-to_fill_sum:].detach().cpu()
480
+ full_visited |= current_visiting
481
+
482
+ if is_training:
483
+ all_flow_preds.append(unpad_flow_predictions)
484
+ all_visconf_preds.append(unpad_visconf_predictions)
485
+ else:
486
+ del unpad_flow_predictions
487
+ del unpad_visconf_predictions
488
+
489
+ flows8 = flows8.reshape(B,S,2,H_pad//8,W_pad//8)
490
+ visconfs8 = visconfs8.reshape(B,S,2,H_pad//8,W_pad//8)
491
+
492
+ if not is_training:
493
+ full_flows = full_flows[:,:T_bak]
494
+ full_visconfs = full_visconfs[:,:T_bak]
495
+
496
+ return full_flows, full_visconfs, all_flow_preds, all_visconf_preds
497
+
498
+ def forward_window(self, fmap1_single, fmaps2, visconfs8, iters=None, flowfeat=None, flows8=None, sw=None, is_training=False):
499
+ B,S,C,H8,W8 = fmaps2.shape
500
+ device = fmaps2.device
501
+ dtype = fmaps2.dtype
502
+
503
+ flow_predictions = []
504
+ visconf_predictions = []
505
+
506
+ fmap1 = fmap1_single.unsqueeze(1).repeat(1,S,1,1,1) # B,S,C,H,W
507
+ fmap1 = fmap1.reshape(B*(S),C,H8,W8).contiguous()
508
+
509
+ fmap2 = fmaps2.reshape(B*(S),C,H8,W8).contiguous()
510
+
511
+ visconfs8 = visconfs8.reshape(B*(S),2,H8,W8).contiguous()
512
+
513
+ corr_fn = CorrBlock(fmap1, fmap2, self.corr_levels, self.corr_radius)
514
+
515
+ coords1 = self.coords_grid(B*(S), H8, W8, device=fmap1.device, dtype=dtype)
516
+
517
+ if self.no_split:
518
+ flowfeat, ctxfeat = fmap1.clone(), fmap1.clone()
519
+ else:
520
+ if flowfeat is not None:
521
+ _, ctxfeat = torch.split(fmap1, [self.dim, self.dim], dim=1)
522
+ else:
523
+ flowfeat, ctxfeat = torch.split(fmap1, [self.dim, self.dim], dim=1)
524
+
525
+ # add pos emb to ctxfeat (and not flowfeat), since ctxfeat is untouched across iters
526
+ time_emb = self.fetch_time_embed(S, ctxfeat.dtype, is_training).reshape(1,S,self.dim,1,1).repeat(B,1,1,1,1)
527
+ ctxfeat = ctxfeat + time_emb.reshape(B*S,self.dim,1,1)
528
+
529
+ if self.no_ctx:
530
+ flowfeat = flowfeat + time_emb.reshape(B*S,self.dim,1,1)
531
+
532
+ for itr in range(iters):
533
+ _, _, H8, W8 = flows8.shape
534
+ flows8 = flows8.detach()
535
+ coords2 = (coords1 + flows8).detach() # B*S,2,H,W
536
+ corr = corr_fn(coords2).to(dtype)
537
+
538
+ if self.use_relmotion or self.use_sinrelmotion:
539
+ coords_ = coords2.reshape(B,S,2,H8*W8).permute(0,1,3,2) # B,S,H8*W8,2
540
+ rel_coords_forward = coords_[:, :-1] - coords_[:, 1:]
541
+ rel_coords_backward = coords_[:, 1:] - coords_[:, :-1]
542
+ rel_coords_forward = torch.nn.functional.pad(
543
+ rel_coords_forward, (0, 0, 0, 0, 0, 1) # pad the 3rd-last dim (S) by (0,1)
544
+ )
545
+ rel_coords_backward = torch.nn.functional.pad(
546
+ rel_coords_backward, (0, 0, 0, 0, 1, 0) # pad the 3rd-last dim (S) by (1,0)
547
+ )
548
+ rel_coords = torch.cat([rel_coords_forward, rel_coords_backward], dim=-1) # B,S,H8*W8,4
549
+
550
+ if self.use_sinrelmotion:
551
+ rel_pos_emb_input = utils.misc.posenc(
552
+ rel_coords,
553
+ min_deg=0,
554
+ max_deg=10,
555
+ ) # B,S,H*W,pdim
556
+ motion = rel_pos_emb_input.reshape(B*S,H8,W8,self.pdim).permute(0,3,1,2).to(dtype) # B*S,pdim,H8,W8
557
+ else:
558
+ motion = rel_coords.reshape(B*S,H8,W8,4).permute(0,3,1,2).to(dtype) # B*S,4,H8,W8
559
+
560
+ else:
561
+ if self.use_sinmotion:
562
+ pos_emb_input = utils.misc.posenc(
563
+ flows8.reshape(B,S,H8*W8,2),
564
+ min_deg=0,
565
+ max_deg=10,
566
+ ) # B,S,H*W,pdim
567
+ motion = pos_emb_input.reshape(B*S,H8,W8,self.pdim).permute(0,3,1,2).to(dtype) # B*S,pdim,H8,W8
568
+ else:
569
+ motion = flows8
570
+
571
+ flowfeat = self.update_block(flowfeat, ctxfeat, visconfs8, corr, motion, S)
572
+ flow_update = self.flow_head(flowfeat)
573
+ visconf_update = self.visconf_head(flowfeat)
574
+ weight_update = .25 * self.upsample_weight(flowfeat)
575
+ flows8 = flows8 + flow_update
576
+ visconfs8 = visconfs8 + visconf_update
577
+ flow_up = self.upsample_data(flows8, weight_update)
578
+ visconf_up = self.upsample_data(visconfs8, weight_update)
579
+ if not is_training: # clear mem
580
+ flow_predictions = []
581
+ visconf_predictions = []
582
+ flow_predictions.append(flow_up)
583
+ visconf_predictions.append(visconf_up)
584
+
585
+ return flow_predictions, visconf_predictions, flows8, visconfs8, flowfeat
586
+
587
+
588
+
nets/blocks.py ADDED
@@ -0,0 +1,1304 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import nn, Tensor
5
+ from itertools import repeat
6
+ import collections
7
+ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
8
+ from functools import partial
9
+ import einops
10
+ import math
11
+ from torchvision.ops.misc import Conv2dNormActivation, Permute
12
+ from torchvision.ops.stochastic_depth import StochasticDepth
+ from einops.layers.torch import Reduce # needed by MLPMixer when do_reduce=True
13
+
14
+ def _ntuple(n):
15
+ def parse(x):
16
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
17
+ return tuple(x)
18
+ return tuple(repeat(x, n))
19
+ return parse
20
+
21
+ def exists(val):
22
+ return val is not None
23
+
24
+ def default(val, d):
25
+ return val if exists(val) else d
26
+
27
+ to_2tuple = _ntuple(2)
28
+
29
+ class InputPadder:
30
+ """ Pads images such that dimensions are divisible by a certain stride """
31
+ def __init__(self, dims, mode='sintel'):
32
+ self.ht, self.wd = dims[-2:]
33
+ pad_ht = (((self.ht // 64) + 1) * 64 - self.ht) % 64
34
+ pad_wd = (((self.wd // 64) + 1) * 64 - self.wd) % 64
35
+ if mode == 'sintel':
36
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
37
+ else:
38
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
39
+
40
+ def pad(self, *inputs):
41
+ return [F.pad(x, self._pad, mode='replicate') for x in inputs]
42
+
43
+ def unpad(self, x):
44
+ ht, wd = x.shape[-2:]
45
+ c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
46
+ return x[..., c[0]:c[1], c[2]:c[3]]
47
+
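A quick round-trip sketch of the padder: dimensions are brought up to the next multiple of 64, and 'sintel' mode splits the padding symmetrically.

import torch
from nets.blocks import InputPadder

x = torch.randn(1, 3, 480, 854)
padder = InputPadder(x.shape)
(x_pad,) = padder.pad(x)
print(x_pad.shape)                # torch.Size([1, 3, 512, 896]) -- next multiples of 64
print(padder.unpad(x_pad).shape)  # torch.Size([1, 3, 480, 854]) -- original size restored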
48
+ def bilinear_sampler(
49
+ input, coords,
50
+ align_corners=True,
51
+ padding_mode="border",
52
+ normalize_coords=True):
53
+ # func from mattie (oct9)
54
+ if input.ndim not in [4, 5]:
55
+ raise ValueError("input must be 4D or 5D.")
56
+
57
+ if input.ndim == 4 and not coords.ndim == 4:
58
+ raise ValueError("input is 4D, but coords is not 4D.")
59
+
60
+ if input.ndim == 5 and not coords.ndim == 5:
61
+ raise ValueError("input is 5D, but coords is not 5D.")
62
+
63
+ if coords.ndim == 5:
64
+ coords = coords[..., [1, 2, 0]] # t x y -> x y t to match what grid_sample() expects.
65
+
66
+ if normalize_coords:
67
+ if align_corners:
68
+ # Normalize coordinates from [0, W/H - 1] to [-1, 1].
69
+ coords = (
70
+ coords
71
+ * torch.tensor([2 / max(size - 1, 1) for size in reversed(input.shape[2:])], device=coords.device)
72
+ - 1
73
+ )
74
+ else:
75
+ # Normalize coordinates from [0, W/H] to [-1, 1].
76
+ coords = coords * torch.tensor([2 / size for size in reversed(input.shape[2:])], device=coords.device) - 1
77
+
78
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
79
+
80
+
81
+ class CorrBlock:
82
+ def __init__(self, fmap1, fmap2, corr_levels, corr_radius):
83
+ self.num_levels = corr_levels
84
+ self.radius = corr_radius
85
+ self.corr_pyramid = []
86
+ # all pairs correlation
87
+ for i in range(self.num_levels):
88
+ corr = CorrBlock.corr(fmap1, fmap2, 1)
89
+ batch, h1, w1, dim, h2, w2 = corr.shape
90
+ corr = corr.reshape(batch*h1*w1, dim, h2, w2)
91
+ fmap2 = F.interpolate(fmap2, scale_factor=0.5, mode='area')
92
+ # print('corr', corr.shape)
93
+ self.corr_pyramid.append(corr)
94
+
95
+ def __call__(self, coords, dilation=None):
96
+ r = self.radius
97
+ coords = coords.permute(0, 2, 3, 1)
98
+ batch, h1, w1, _ = coords.shape
99
+
100
+ if dilation is None:
101
+ dilation = torch.ones(batch, 1, h1, w1, device=coords.device)
102
+
103
+ out_pyramid = []
104
+ for i in range(self.num_levels):
105
+ corr = self.corr_pyramid[i]
106
+ device = coords.device
107
+ dx = torch.linspace(-r, r, 2*r+1, device=device)
108
+ dy = torch.linspace(-r, r, 2*r+1, device=device)
109
+ delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1)
110
+ delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
111
+ delta_lvl = delta_lvl * dilation.view(batch * h1 * w1, 1, 1, 1)
112
+ centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
113
+ coords_lvl = centroid_lvl + delta_lvl
114
+ corr = bilinear_sampler(corr, coords_lvl)
115
+ corr = corr.view(batch, h1, w1, -1)
116
+ out_pyramid.append(corr)
117
+
118
+ out = torch.cat(out_pyramid, dim=-1)
119
+ out = out.permute(0, 3, 1, 2).contiguous().float()
120
+ return out
121
+
122
+ @staticmethod
123
+ def corr(fmap1, fmap2, num_head):
124
+ batch, dim, h1, w1 = fmap1.shape
125
+ h2, w2 = fmap2.shape[2:]
126
+ fmap1 = fmap1.view(batch, num_head, dim // num_head, h1*w1)
127
+ fmap2 = fmap2.view(batch, num_head, dim // num_head, h2*w2)
128
+ corr = fmap1.transpose(2, 3) @ fmap2
129
+ corr = corr.reshape(batch, num_head, h1, w1, h2, w2).permute(0, 2, 3, 1, 4, 5)
130
+ return corr / torch.sqrt(torch.tensor(dim).float())
131
+
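A small sketch of how the pyramid is queried: one call per set of query coordinates returns corr_levels * (2*corr_radius + 1)**2 channels per pixel, the same count Net stores as corr_channel (values below are illustrative).

import torch
from nets.blocks import CorrBlock

B, C, H, W = 2, 128, 32, 40
fmap1 = torch.randn(B, C, H, W)
fmap2 = torch.randn(B, C, H, W)
corr_fn = CorrBlock(fmap1, fmap2, corr_levels=4, corr_radius=3)

coords = torch.zeros(B, 2, H, W)  # (x, y) query positions into fmap2
corr = corr_fn(coords)
print(corr.shape)                 # torch.Size([2, 196, 32, 40]) -- 4 * (2*3 + 1)**2 = 196 channels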
132
+ def conv1x1(in_planes, out_planes, stride=1):
133
+ """1x1 convolution without padding"""
134
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0)
135
+
136
+ def conv3x3(in_planes, out_planes, stride=1):
137
+ """3x3 convolution with padding"""
138
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
139
+
140
+ class LayerNorm2d(nn.LayerNorm):
141
+ def forward(self, x: Tensor) -> Tensor:
142
+ x = x.permute(0, 2, 3, 1)
143
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
144
+ x = x.permute(0, 3, 1, 2)
145
+ return x
146
+
147
+ class CNBlock1d(nn.Module):
148
+ def __init__(
149
+ self,
150
+ dim,
151
+ output_dim,
152
+ layer_scale: float = 1e-6,
153
+ stochastic_depth_prob: float = 0,
154
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
155
+ dense=True,
156
+ use_attn=True,
157
+ use_mixer=False,
158
+ use_conv=False,
159
+ use_convb=False,
160
+ use_layer_scale=True,
161
+ ) -> None:
162
+ super().__init__()
163
+ self.dense = dense
164
+ self.use_attn = use_attn
165
+ self.use_mixer = use_mixer
166
+ self.use_conv = use_conv
167
+ self.use_layer_scale = use_layer_scale
168
+
169
+ if use_attn:
170
+ assert not use_mixer
171
+ assert not use_conv
172
+ assert not use_convb
173
+
174
+ if norm_layer is None:
175
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
176
+
177
+ if use_attn:
178
+ num_heads = 8
179
+ self.block = AttnBlock(
180
+ hidden_size=dim,
181
+ num_heads=num_heads,
182
+ mlp_ratio=4,
183
+ attn_class=Attention,
184
+ )
185
+ elif use_mixer:
186
+ self.block = MLPMixerBlock(
187
+ S=16,
188
+ dim=dim,
189
+ depth=1,
190
+ expansion_factor=2,
191
+ )
192
+ elif use_conv:
193
+ self.block = nn.Sequential(
194
+ nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True, padding_mode='zeros'),
195
+ Permute([0, 2, 1]),
196
+ norm_layer(dim),
197
+ nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
198
+ nn.GELU(),
199
+ nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
200
+ Permute([0, 2, 1]),
201
+ )
202
+ elif use_convb:
203
+ self.block = nn.Sequential(
204
+ nn.Conv1d(dim, dim, kernel_size=3, padding=1, bias=True, padding_mode='zeros'),
205
+ Permute([0, 2, 1]),
206
+ norm_layer(dim),
207
+ nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
208
+ nn.GELU(),
209
+ nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
210
+ Permute([0, 2, 1]),
211
+ )
212
+ else:
213
+ assert(False) # choose attn, mixer, or conv please
214
+
215
+ if self.use_layer_scale:
216
+ self.layer_scale = nn.Parameter(torch.ones(dim, 1) * layer_scale)
217
+ else:
218
+ self.layer_scale = 1.0
219
+
220
+ self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
221
+
222
+ if output_dim != dim:
223
+ self.final = nn.Conv1d(dim, output_dim, kernel_size=1, padding=0)
224
+ else:
225
+ self.final = nn.Identity()
226
+
227
+ def forward(self, input, S=None):
228
+ if self.dense:
229
+ assert S is not None
230
+ BS,C,H,W = input.shape
231
+ B = BS//S
232
+
233
+ input = einops.rearrange(input, '(b s) c h w -> (b h w) c s', b=B, s=S, c=C, h=H, w=W)
234
+
235
+ if self.use_mixer or self.use_attn:
236
+ # mixer/transformer blocks want B,S,C
237
+ result = self.layer_scale * self.block(input.permute(0,2,1)).permute(0,2,1)
238
+ else:
239
+ result = self.layer_scale * self.block(input)
240
+ result = self.stochastic_depth(result)
241
+ result += input
242
+ result = self.final(result)
243
+
244
+ result = einops.rearrange(result, '(b h w) c s -> (b s) c h w', b=B, s=S, c=C, h=H, w=W)
245
+ else:
246
+ B,S,C = input.shape
247
+
248
+ if S<7:
249
+ return input
250
+
251
+ input = einops.rearrange(input, 'b s c -> b c s', b=B, s=S, c=C)
252
+
253
+ result = self.layer_scale * self.block(input)
254
+ result = self.stochastic_depth(result)
255
+ result += input
256
+
257
+ result = self.final(result)
258
+
259
+ result = einops.rearrange(result, 'b c s -> b s c', b=B, s=S, c=C)
260
+
261
+ return result
262
+
263
+ class CNBlock2d(nn.Module):
264
+ def __init__(
265
+ self,
266
+ dim,
267
+ output_dim,
268
+ layer_scale: float = 1e-6,
269
+ stochastic_depth_prob: float = 0,
270
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
271
+ use_layer_scale=True,
272
+ ) -> None:
273
+ super().__init__()
274
+ self.use_layer_scale = use_layer_scale
275
+ if norm_layer is None:
276
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
277
+
278
+ self.block = nn.Sequential(
279
+ nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True, padding_mode='zeros'),
280
+ Permute([0, 2, 3, 1]),
281
+ norm_layer(dim),
282
+ nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
283
+ nn.GELU(),
284
+ nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
285
+ Permute([0, 3, 1, 2]),
286
+ )
287
+ if self.use_layer_scale:
288
+ self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
289
+ else:
290
+ self.layer_scale = 1.0
291
+ self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
292
+
293
+ if output_dim != dim:
294
+ self.final = nn.Conv2d(dim, output_dim, kernel_size=1, padding=0)
295
+ else:
296
+ self.final = nn.Identity()
297
+
298
+ def forward(self, input, S=None):
299
+ result = self.layer_scale * self.block(input)
300
+ result = self.stochastic_depth(result)
301
+ result += input
302
+ result = self.final(result)
303
+ return result
304
+
305
+ class CNBlockConfig:
306
+ # Stores information listed at Section 3 of the ConvNeXt paper
307
+ def __init__(
308
+ self,
309
+ input_channels: int,
310
+ out_channels: Optional[int],
311
+ num_layers: int,
312
+ downsample: bool,
313
+ ) -> None:
314
+ self.input_channels = input_channels
315
+ self.out_channels = out_channels
316
+ self.num_layers = num_layers
317
+ self.downsample = downsample
318
+
319
+ def __repr__(self) -> str:
320
+ s = self.__class__.__name__ + "("
321
+ s += "input_channels={input_channels}"
322
+ s += ", out_channels={out_channels}"
323
+ s += ", num_layers={num_layers}"
324
+ s += ", downsample={downsample}"
325
+ s += ")"
326
+ return s.format(**self.__dict__)
327
+
328
+ class ConvNeXt(nn.Module):
329
+ def __init__(
330
+ self,
331
+ block_setting: List[CNBlockConfig],
332
+ stochastic_depth_prob: float = 0.0,
333
+ layer_scale: float = 1e-6,
334
+ num_classes: int = 1000,
335
+ block: Optional[Callable[..., nn.Module]] = None,
336
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
337
+ init_weights=True):
338
+ super().__init__()
339
+
340
+ self.init_weights = init_weights
341
+
342
+ if not block_setting:
343
+ raise ValueError("The block_setting should not be empty")
344
+ elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
345
+ raise TypeError("The block_setting should be List[CNBlockConfig]")
346
+
347
+ if block is None:
348
+ block = CNBlock2d
349
+
350
+ if norm_layer is None:
351
+ norm_layer = partial(LayerNorm2d, eps=1e-6)
352
+
353
+ layers: List[nn.Module] = []
354
+
355
+ # Stem
356
+ firstconv_output_channels = block_setting[0].input_channels
357
+ layers.append(
358
+ Conv2dNormActivation(
359
+ 3,
360
+ firstconv_output_channels,
361
+ kernel_size=4,
362
+ stride=4,
363
+ padding=0,
364
+ norm_layer=norm_layer,
365
+ activation_layer=None,
366
+ bias=True,
367
+ )
368
+ )
369
+
370
+ total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
371
+ stage_block_id = 0
372
+ for cnf in block_setting:
373
+ # Bottlenecks
374
+ stage: List[nn.Module] = []
375
+ for _ in range(cnf.num_layers):
376
+ # adjust stochastic depth probability based on the depth of the stage block
377
+ sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
378
+ stage.append(block(cnf.input_channels, cnf.input_channels, layer_scale, sd_prob))
379
+ stage_block_id += 1
380
+ layers.append(nn.Sequential(*stage))
381
+ if cnf.out_channels is not None:
382
+ if cnf.downsample:
383
+ layers.append(
384
+ nn.Sequential(
385
+ norm_layer(cnf.input_channels),
386
+ nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
387
+ )
388
+ )
389
+ else:
390
+ # we convert the 2x2 downsampling layer into a 3x3 with dilation2 and replicate padding.
391
+ # replicate padding compensates for the fact that this kernel never saw zero-padding.
392
+ layers.append(
393
+ nn.Sequential(
394
+ norm_layer(cnf.input_channels),
395
+ nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=3, stride=1, padding=2, dilation=2, padding_mode='zeros'),
396
+ )
397
+ )
398
+
399
+ self.features = nn.Sequential(*layers)
400
+
401
+ # self.final_conv = conv1x1(block_setting[-1].input_channels, output_dim)
402
+
403
+ for m in self.modules():
404
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
405
+ nn.init.trunc_normal_(m.weight, std=0.02)
406
+ if m.bias is not None:
407
+ nn.init.zeros_(m.bias)
408
+
409
+ if self.init_weights:
410
+ from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights
411
+ pretrained_dict = convnext_tiny(weights=ConvNeXt_Tiny_Weights.DEFAULT).state_dict()
412
+ # from torchvision.models import convnext_base, ConvNeXt_Base_Weights
413
+ # pretrained_dict = convnext_base(weights=ConvNeXt_Base_Weights.DEFAULT).state_dict()
414
+ model_dict = self.state_dict()
415
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
416
+
417
+ for k, v in pretrained_dict.items():
418
+ if k == 'features.4.1.weight': # this is the layer normally in charge of 2x2 downsampling
419
+ # convert to 3x3 filter
420
+ pretrained_dict[k] = F.interpolate(v, (3, 3), mode='bicubic', align_corners=True) * (4/9.0)
421
+
422
+ model_dict.update(pretrained_dict)
423
+ self.load_state_dict(model_dict, strict=False)
424
+
425
+
426
+ def _forward_impl(self, x: Tensor) -> Tensor:
427
+ x = self.features(x)
428
+ # x = self.final_conv(x)
429
+ return x
430
+
431
+ def forward(self, x: Tensor) -> Tensor:
432
+ return self._forward_impl(x)
433
+
434
+ class Mlp(nn.Module):
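The weight surgery above resizes the pretrained 2x2 stride-2 downsampling kernel (features.4.1) into a 3x3 kernel for the dilated stride-1 replacement; the 4/9 factor roughly preserves the response magnitude when the kernel grows from 4 taps to 9. A minimal sketch of that resize on a dummy weight of the expected shape:

import torch
import torch.nn.functional as F

w = torch.randn(384, 192, 2, 2)  # dummy stand-in for the pretrained 2x2 downsampling weight (out, in, kh, kw)
w3 = F.interpolate(w, (3, 3), mode='bicubic', align_corners=True) * (4 / 9.0)
print(w3.shape)                  # torch.Size([384, 192, 3, 3]) -- fits the 3x3 dilated replacement conv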
435
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
436
+
437
+ def __init__(
438
+ self,
439
+ in_features,
440
+ hidden_features=None,
441
+ out_features=None,
442
+ act_layer=nn.GELU,
443
+ norm_layer=None,
444
+ bias=True,
445
+ drop=0.0,
446
+ use_conv=False,
447
+ ):
448
+ super().__init__()
449
+ out_features = out_features or in_features
450
+ hidden_features = hidden_features or in_features
451
+ bias = to_2tuple(bias)
452
+ drop_probs = to_2tuple(drop)
453
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
454
+
455
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
456
+ self.act = act_layer()
457
+ self.drop1 = nn.Dropout(drop_probs[0])
458
+ self.norm = (
459
+ norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
460
+ )
461
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
462
+ self.drop2 = nn.Dropout(drop_probs[1])
463
+
464
+ def forward(self, x):
465
+ x = self.fc1(x)
466
+ x = self.act(x)
467
+ x = self.drop1(x)
468
+ x = self.fc2(x)
469
+ x = self.drop2(x)
470
+ return x
471
+
472
+ class Attention(nn.Module):
473
+ def __init__(
474
+ self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False
475
+ ):
476
+ super().__init__()
477
+ inner_dim = dim_head * num_heads
478
+ context_dim = default(context_dim, query_dim)
479
+ self.scale = dim_head**-0.5
480
+ self.heads = num_heads
481
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
482
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
483
+ self.to_out = nn.Linear(inner_dim, query_dim)
484
+
485
+ def forward(self, x, context=None, attn_bias=None):
486
+ B, N1, C = x.shape
487
+ H = self.heads
488
+ q = self.to_q(x)
489
+ context = default(context, x)
490
+ k, v = self.to_kv(context).chunk(2, dim=-1)
491
+ q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> b h n d', h=self.heads), (q, k, v))
492
+ x = F.scaled_dot_product_attention(q, k, v) # scale default is already dim^-0.5
493
+ x = einops.rearrange(x, 'b h n d -> b n (h d)')
494
+ return self.to_out(x)
495
+
496
+ class CrossAttnBlock(nn.Module):
497
+ def __init__(
498
+ self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs
499
+ ):
500
+ super().__init__()
501
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
502
+ self.norm_context = nn.LayerNorm(hidden_size)
503
+ self.cross_attn = Attention(
504
+ hidden_size,
505
+ context_dim=context_dim,
506
+ num_heads=num_heads,
507
+ qkv_bias=True,
508
+ **block_kwargs
509
+ )
510
+
511
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
512
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
513
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
514
+ self.mlp = Mlp(
515
+ in_features=hidden_size,
516
+ hidden_features=mlp_hidden_dim,
517
+ act_layer=approx_gelu,
518
+ drop=0,
519
+ )
520
+
521
+ def forward(self, x, context, mask=None):
522
+ attn_bias = None
523
+ if mask is not None:
524
+ if mask.shape[1] == x.shape[1]:
525
+ mask = mask[:, None, :, None].expand(
526
+ -1, self.cross_attn.heads, -1, context.shape[1]
527
+ )
528
+ else:
529
+ mask = mask[:, None, None].expand(
530
+ -1, self.cross_attn.heads, x.shape[1], -1
531
+ )
532
+
533
+ max_neg_value = -torch.finfo(x.dtype).max
534
+ attn_bias = (~mask) * max_neg_value
535
+ x = x + self.cross_attn(
536
+ self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
537
+ )
538
+ x = x + self.mlp(self.norm2(x))
539
+ return x
540
+
541
+ class AttnBlock(nn.Module):
542
+ def __init__(
543
+ self,
544
+ hidden_size,
545
+ num_heads,
546
+ attn_class: Callable[..., nn.Module] = Attention,
547
+ mlp_ratio=4.0,
548
+ **block_kwargs
549
+ ):
550
+ super().__init__()
551
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
552
+ self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, dim_head=hidden_size//num_heads)
553
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
554
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
555
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
556
+ self.mlp = Mlp(
557
+ in_features=hidden_size,
558
+ hidden_features=mlp_hidden_dim,
559
+ act_layer=approx_gelu,
560
+ drop=0,
561
+ )
562
+
563
+ def forward(self, x, mask=None):
564
+ attn_bias = mask
565
+ if mask is not None:
566
+ mask = (
567
+ (mask[:, None] * mask[:, :, None])
568
+ .unsqueeze(1)
569
+ .expand(-1, self.attn.heads, -1, -1) # Attention exposes .heads, not .num_heads
570
+ )
571
+ max_neg_value = -torch.finfo(x.dtype).max
572
+ attn_bias = (~mask) * max_neg_value
573
+
574
+ x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
575
+ x = x + self.mlp(self.norm2(x))
576
+ return x
577
+
578
+
579
+ class ResidualBlock(nn.Module):
580
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1):
581
+ super(ResidualBlock, self).__init__()
582
+
583
+ self.conv1 = nn.Conv2d(
584
+ in_planes,
585
+ planes,
586
+ kernel_size=3,
587
+ padding=1,
588
+ stride=stride,
589
+ padding_mode="zeros",
590
+ )
591
+ self.conv2 = nn.Conv2d(
592
+ planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
593
+ )
594
+ self.relu = nn.ReLU(inplace=True)
595
+
596
+ num_groups = planes // 8
597
+
598
+ if norm_fn == "group":
599
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
600
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
601
+ if not stride == 1:
602
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
603
+
604
+ elif norm_fn == "batch":
605
+ self.norm1 = nn.BatchNorm2d(planes)
606
+ self.norm2 = nn.BatchNorm2d(planes)
607
+ if not stride == 1:
608
+ self.norm3 = nn.BatchNorm2d(planes)
609
+
610
+ elif norm_fn == "instance":
611
+ self.norm1 = nn.InstanceNorm2d(planes)
612
+ self.norm2 = nn.InstanceNorm2d(planes)
613
+ if not stride == 1:
614
+ self.norm3 = nn.InstanceNorm2d(planes)
615
+
616
+ elif norm_fn == "none":
617
+ self.norm1 = nn.Sequential()
618
+ self.norm2 = nn.Sequential()
619
+ if not stride == 1:
620
+ self.norm3 = nn.Sequential()
621
+
622
+ if stride == 1:
623
+ self.downsample = None
624
+
625
+ else:
626
+ self.downsample = nn.Sequential(
627
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
628
+ )
629
+
630
+ def forward(self, x):
631
+ y = x
632
+ y = self.relu(self.norm1(self.conv1(y)))
633
+ y = self.relu(self.norm2(self.conv2(y)))
634
+
635
+ if self.downsample is not None:
636
+ x = self.downsample(x)
637
+
638
+ return self.relu(x + y)
639
+
640
+
641
+ class BasicEncoder(nn.Module):
642
+ def __init__(self, input_dim=3, output_dim=128, stride=4):
643
+ super(BasicEncoder, self).__init__()
644
+ self.stride = stride
645
+ self.norm_fn = "instance"
646
+ self.in_planes = output_dim // 2
647
+ self.norm1 = nn.InstanceNorm2d(self.in_planes)
648
+ self.norm2 = nn.InstanceNorm2d(output_dim * 2)
649
+
650
+ self.conv1 = nn.Conv2d(
651
+ input_dim,
652
+ self.in_planes,
653
+ kernel_size=7,
654
+ stride=2,
655
+ padding=3,
656
+ padding_mode="zeros",
657
+ )
658
+ self.relu1 = nn.ReLU(inplace=True)
659
+ self.layer1 = self._make_layer(output_dim // 2, stride=1)
660
+ self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
661
+ self.layer3 = self._make_layer(output_dim, stride=2)
662
+ self.layer4 = self._make_layer(output_dim, stride=2)
663
+
664
+ self.conv2 = nn.Conv2d(
665
+ output_dim * 3 + output_dim // 4,
666
+ output_dim * 2,
667
+ kernel_size=3,
668
+ padding=1,
669
+ padding_mode="zeros",
670
+ )
671
+ self.relu2 = nn.ReLU(inplace=True)
672
+ self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
673
+ for m in self.modules():
674
+ if isinstance(m, nn.Conv2d):
675
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
676
+ elif isinstance(m, (nn.InstanceNorm2d)):
677
+ if m.weight is not None:
678
+ nn.init.constant_(m.weight, 1)
679
+ if m.bias is not None:
680
+ nn.init.constant_(m.bias, 0)
681
+
682
+ def _make_layer(self, dim, stride=1):
683
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
684
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
685
+ layers = (layer1, layer2)
686
+
687
+ self.in_planes = dim
688
+ return nn.Sequential(*layers)
689
+
690
+ def forward(self, x):
691
+ _, _, H, W = x.shape
692
+
693
+ x = self.conv1(x)
694
+ x = self.norm1(x)
695
+ x = self.relu1(x)
696
+
697
+ a = self.layer1(x)
698
+ b = self.layer2(a)
699
+ c = self.layer3(b)
700
+ d = self.layer4(c)
701
+
702
+ def _bilinear_interpolate(x):
703
+ return F.interpolate(
704
+ x,
705
+ (H // self.stride, W // self.stride),
706
+ mode="bilinear",
707
+ align_corners=True,
708
+ )
709
+
710
+ a = _bilinear_interpolate(a)
711
+ b = _bilinear_interpolate(b)
712
+ c = _bilinear_interpolate(c)
713
+ d = _bilinear_interpolate(d)
714
+
715
+ x = self.conv2(torch.cat([a, b, c, d], dim=1))
716
+ x = self.norm2(x)
717
+ x = self.relu2(x)
718
+ x = self.conv3(x)
719
+ return x
720
+
721
+ class EfficientUpdateFormer(nn.Module):
722
+ """
723
+ Transformer model that updates track estimates.
724
+ """
725
+
726
+ def __init__(
727
+ self,
728
+ space_depth=6,
729
+ time_depth=6,
730
+ input_dim=320,
731
+ hidden_size=384,
732
+ num_heads=8,
733
+ output_dim=130,
734
+ mlp_ratio=4.0,
735
+ num_virtual_tracks=64,
736
+ add_space_attn=True,
737
+ linear_layer_for_vis_conf=False,
738
+ use_time_conv=False,
739
+ use_time_mixer=False,
740
+ ):
741
+ super().__init__()
742
+ self.out_channels = 2
743
+ self.num_heads = num_heads
744
+ self.hidden_size = hidden_size
745
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
746
+ if linear_layer_for_vis_conf:
747
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
748
+ self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
749
+ else:
750
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
751
+ self.num_virtual_tracks = num_virtual_tracks
752
+ self.virual_tracks = nn.Parameter(
753
+ torch.randn(1, num_virtual_tracks, 1, hidden_size)
754
+ )
755
+ self.add_space_attn = add_space_attn
756
+ self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
757
+
758
+ if use_time_conv:
759
+ self.time_blocks = nn.ModuleList(
760
+ [
761
+ CNBlock1d(hidden_size, hidden_size, dense=False)
762
+ for _ in range(time_depth)
763
+ ]
764
+ )
765
+ elif use_time_mixer:
766
+ self.time_blocks = nn.ModuleList(
767
+ [
768
+ MLPMixerBlock(
769
+ S=16,
770
+ dim=hidden_size,
771
+ depth=1,
772
+ )
773
+ for _ in range(time_depth)
774
+ ]
775
+ )
776
+ else:
777
+ self.time_blocks = nn.ModuleList(
778
+ [
779
+ AttnBlock(
780
+ hidden_size,
781
+ num_heads,
782
+ mlp_ratio=mlp_ratio,
783
+ attn_class=Attention,
784
+ )
785
+ for _ in range(time_depth)
786
+ ]
787
+ )
788
+
789
+ if add_space_attn:
790
+ self.space_virtual_blocks = nn.ModuleList(
791
+ [
792
+ AttnBlock(
793
+ hidden_size,
794
+ num_heads,
795
+ mlp_ratio=mlp_ratio,
796
+ attn_class=Attention,
797
+ )
798
+ for _ in range(space_depth)
799
+ ]
800
+ )
801
+ self.space_point2virtual_blocks = nn.ModuleList(
802
+ [
803
+ CrossAttnBlock(
804
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
805
+ )
806
+ for _ in range(space_depth)
807
+ ]
808
+ )
809
+ self.space_virtual2point_blocks = nn.ModuleList(
810
+ [
811
+ CrossAttnBlock(
812
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
813
+ )
814
+ for _ in range(space_depth)
815
+ ]
816
+ )
817
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
818
+ self.initialize_weights()
819
+
820
+ def initialize_weights(self):
821
+ def _basic_init(module):
822
+ if isinstance(module, nn.Linear):
823
+ torch.nn.init.xavier_uniform_(module.weight)
824
+ if module.bias is not None:
825
+ nn.init.constant_(module.bias, 0)
826
+ torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
827
+ if self.linear_layer_for_vis_conf:
828
+ torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)
829
+
830
+ def _trunc_init(module):
831
+ """ViT weight initialization, original timm impl (for reproducibility)"""
832
+ if isinstance(module, nn.Linear):
833
+ torch.nn.init.trunc_normal_(module.weight, std=0.02)
834
+ if module.bias is not None:
835
+ nn.init.zeros_(module.bias)
836
+
837
+ self.apply(_basic_init)
838
+
839
+ def forward(self, input_tensor, mask=None, add_space_attn=True):
840
+ tokens = self.input_transform(input_tensor)
841
+
842
+ B, _, T, _ = tokens.shape
843
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
844
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
845
+
846
+ _, N, _, _ = tokens.shape
847
+ j = 0
848
+ layers = []
849
+ for i in range(len(self.time_blocks)):
850
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
851
+ time_tokens = self.time_blocks[i](time_tokens)
852
+
853
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
854
+ if (
855
+ add_space_attn
856
+ and hasattr(self, "space_virtual_blocks")
857
+ and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
858
+ ):
859
+ space_tokens = (
860
+ tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
861
+ ) # B N T C -> (B T) N C
862
+
863
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
864
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
865
+
866
+ virtual_tokens = self.space_virtual2point_blocks[j](
867
+ virtual_tokens, point_tokens, mask=mask
868
+ )
869
+
870
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
871
+ point_tokens = self.space_point2virtual_blocks[j](
872
+ point_tokens, virtual_tokens, mask=mask
873
+ )
874
+
875
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
876
+ tokens = space_tokens.view(B, T, N, -1).permute(
877
+ 0, 2, 1, 3
878
+ ) # (B T) N C -> B N T C
879
+ j += 1
880
+ tokens = tokens[:, : N - self.num_virtual_tracks]
881
+
882
+ flow = self.flow_head(tokens)
883
+ if self.linear_layer_for_vis_conf:
884
+ vis_conf = self.vis_conf_head(tokens)
885
+ flow = torch.cat([flow, vis_conf], dim=-1)
886
+
887
+ return flow
888
+
889
+
890
+ class MMPreNormResidual(nn.Module):
891
+ def __init__(self, dim, fn):
892
+ super().__init__()
893
+ self.fn = fn
894
+ self.norm = nn.LayerNorm(dim)
895
+
896
+ def forward(self, x):
897
+ return self.fn(self.norm(x)) + x
898
+
899
+ def MMFeedForward(dim, expansion_factor=4, dropout=0., dense=nn.Linear):
900
+ return nn.Sequential(
901
+ dense(dim, dim * expansion_factor),
902
+ nn.GELU(),
903
+ nn.Dropout(dropout),
904
+ dense(dim * expansion_factor, dim),
905
+ nn.Dropout(dropout)
906
+ )
907
+
908
+ def MLPMixer(S, input_dim, dim, output_dim, depth=6, expansion_factor=4, dropout=0., do_reduce=False):
909
+ # input is coming in as B,S,C, as standard for mlp and transformer
910
+ # chan_first treats S as the channel dim, and transforms it to a new S
911
+ # chan_last treats C as the channel dim, and transforms it to a new C
912
+ chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
913
+ if do_reduce:
914
+ return nn.Sequential(
915
+ nn.Linear(input_dim, dim),
916
+ *[nn.Sequential(
917
+ MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
918
+ MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
919
+ ) for _ in range(depth)],
920
+ nn.LayerNorm(dim),
921
+ Reduce('b n c -> b c', 'mean'),
922
+ nn.Linear(dim, output_dim)
923
+ )
924
+ else:
925
+ return nn.Sequential(
926
+ nn.Linear(input_dim, dim),
927
+ *[nn.Sequential(
928
+ MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
929
+ MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
930
+ ) for _ in range(depth)],
931
+ )
932
+
933
+ def MLPMixerBlock(S, dim, depth=1, expansion_factor=4, dropout=0., do_reduce=False):
934
+ # input is coming in as B,S,C, as standard for mlp and transformer
935
+ # chan_first treats S as the channel dim, and transforms it to a new S
936
+ # chan_last treats C as the channel dim, and transforms it to a new C
937
+ chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
938
+ return nn.Sequential(
939
+ *[nn.Sequential(
940
+ MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
941
+ MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
942
+ ) for _ in range(depth)],
943
+ )
944
+
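The mixers above alternate token mixing (treating the S tokens as Conv1d channels) with channel mixing (plain Linear layers). A minimal shape check for MLPMixer, added here as an editor's sketch and not part of the commit; it assumes functools.partial (and einops.Reduce, for the do_reduce branch) are imported near the top of this file:

import torch

mixer = MLPMixer(S=16, input_dim=384, dim=384, output_dim=384, depth=1)
x = torch.zeros(8, 16, 384)   # B, S, C -- S must equal the S the mixer was built with
y = mixer(x)                  # token mixing over S, then channel mixing over C
print(y.shape)                # torch.Size([8, 16, 384])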
945
+
946
+ class MlpUpdateFormer(nn.Module):
947
+ """
948
+ Transformer model that updates track estimates.
949
+ """
950
+
951
+ def __init__(
952
+ self,
953
+ space_depth=6,
954
+ time_depth=6,
955
+ input_dim=320,
956
+ hidden_size=384,
957
+ num_heads=8,
958
+ output_dim=130,
959
+ mlp_ratio=4.0,
960
+ num_virtual_tracks=64,
961
+ add_space_attn=True,
962
+ linear_layer_for_vis_conf=False,
963
+ ):
964
+ super().__init__()
965
+ self.out_channels = 2
966
+ self.num_heads = num_heads
967
+ self.hidden_size = hidden_size
968
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
969
+ if linear_layer_for_vis_conf:
970
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
971
+ self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
972
+ else:
973
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
974
+ self.num_virtual_tracks = num_virtual_tracks
975
+ self.virual_tracks = nn.Parameter(
976
+ torch.randn(1, num_virtual_tracks, 1, hidden_size)
977
+ )
978
+ self.add_space_attn = add_space_attn
979
+ self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
980
+ self.time_blocks = nn.ModuleList(
981
+ [
982
+ MLPMixer(
983
+ S=16,
984
+ input_dim=hidden_size,
985
+ dim=hidden_size,
986
+ output_dim=hidden_size,
987
+ depth=1,
988
+ )
989
+ for _ in range(time_depth)
990
+ ]
991
+ )
992
+
993
+ if add_space_attn:
994
+ self.space_virtual_blocks = nn.ModuleList(
995
+ [
996
+ AttnBlock(
997
+ hidden_size,
998
+ num_heads,
999
+ mlp_ratio=mlp_ratio,
1000
+ attn_class=Attention,
1001
+ )
1002
+ for _ in range(space_depth)
1003
+ ]
1004
+ )
1005
+ self.space_point2virtual_blocks = nn.ModuleList(
1006
+ [
1007
+ CrossAttnBlock(
1008
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
1009
+ )
1010
+ for _ in range(space_depth)
1011
+ ]
1012
+ )
1013
+ self.space_virtual2point_blocks = nn.ModuleList(
1014
+ [
1015
+ CrossAttnBlock(
1016
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
1017
+ )
1018
+ for _ in range(space_depth)
1019
+ ]
1020
+ )
1021
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
1022
+ self.initialize_weights()
1023
+
1024
+ def initialize_weights(self):
1025
+ def _basic_init(module):
1026
+ if isinstance(module, nn.Linear):
1027
+ torch.nn.init.xavier_uniform_(module.weight)
1028
+ if module.bias is not None:
1029
+ nn.init.constant_(module.bias, 0)
1030
+ torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
1031
+ if self.linear_layer_for_vis_conf:
1032
+ torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)
1033
+
1034
+ def _trunc_init(module):
1035
+ """ViT weight initialization, original timm impl (for reproducibility)"""
1036
+ if isinstance(module, nn.Linear):
1037
+ torch.nn.init.trunc_normal_(module.weight, std=0.02)
1038
+ if module.bias is not None:
1039
+ nn.init.zeros_(module.bias)
1040
+
1041
+ self.apply(_basic_init)
1042
+
1043
+ def forward(self, input_tensor, mask=None, add_space_attn=True):
1044
+ tokens = self.input_transform(input_tensor)
1045
+
1046
+ B, _, T, _ = tokens.shape
1047
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
1048
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
1049
+
1050
+ _, N, _, _ = tokens.shape
1051
+ j = 0
1052
+ layers = []
1053
+ for i in range(len(self.time_blocks)):
1054
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
1055
+ time_tokens = self.time_blocks[i](time_tokens)
1056
+
1057
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
1058
+ if (
1059
+ add_space_attn
1060
+ and hasattr(self, "space_virtual_blocks")
1061
+ and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
1062
+ ):
1063
+ space_tokens = (
1064
+ tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
1065
+ ) # B N T C -> (B T) N C
1066
+
1067
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
1068
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
1069
+
1070
+ virtual_tokens = self.space_virtual2point_blocks[j](
1071
+ virtual_tokens, point_tokens, mask=mask
1072
+ )
1073
+
1074
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
1075
+ point_tokens = self.space_point2virtual_blocks[j](
1076
+ point_tokens, virtual_tokens, mask=mask
1077
+ )
1078
+
1079
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
1080
+ tokens = space_tokens.view(B, T, N, -1).permute(
1081
+ 0, 2, 1, 3
1082
+ ) # (B T) N C -> B N T C
1083
+ j += 1
1084
+ tokens = tokens[:, : N - self.num_virtual_tracks]
1085
+
1086
+ flow = self.flow_head(tokens)
1087
+ if self.linear_layer_for_vis_conf:
1088
+ vis_conf = self.vis_conf_head(tokens)
1089
+ flow = torch.cat([flow, vis_conf], dim=-1)
1090
+
1091
+ return flow
1092
+
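A shape walkthrough for MlpUpdateFormer, given as an editor's sketch rather than part of the commit: the track tokens are augmented with 64 learned virtual tracks, mixed over time (T must be 16, matching the MLPMixer above) and over space, and finally mapped to per-timestep flow updates plus visibility/confidence logits. It assumes AttnBlock, CrossAttnBlock, and Attention resolve to the attention blocks used elsewhere in this repo.

import torch

model = MlpUpdateFormer(input_dim=320, hidden_size=384, output_dim=130,
                        num_virtual_tracks=64, linear_layer_for_vis_conf=True)
x = torch.zeros(2, 256, 16, 320)   # B, N tracks, T=16 timesteps, input_dim
delta = model(x)                   # virtual tracks are stripped before the output heads
print(delta.shape)                 # torch.Size([2, 256, 16, 130]) -- 128 flow dims + 2 vis/conf logits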
1093
+ class BasicMotionEncoder(nn.Module):
1094
+ def __init__(self, corr_channel, dim=128, pdim=2):
1095
+ super(BasicMotionEncoder, self).__init__()
1096
+ self.pdim = pdim
1097
+ self.convc1 = nn.Conv2d(corr_channel, dim*4, 1, padding=0)
1098
+ self.convc2 = nn.Conv2d(dim*4, dim+dim//2, 3, padding=1)
1099
+ if pdim==2 or pdim==4:
1100
+ self.convf1 = nn.Conv2d(pdim, dim*2, 5, padding=2)
1101
+ self.convf2 = nn.Conv2d(dim*2, dim//2, 3, padding=1)
1102
+ self.conv = nn.Conv2d(dim*2, dim-pdim, 3, padding=1)
1103
+ else:
1104
+ self.conv = nn.Conv2d(dim+dim//2+pdim, dim, 3, padding=1)
1105
+
1106
+ def forward(self, flow, corr):
1107
+ cor = F.relu(self.convc1(corr))
1108
+ cor = F.relu(self.convc2(cor))
1109
+ if self.pdim==2 or self.pdim==4:
1110
+ flo = F.relu(self.convf1(flow))
1111
+ flo = F.relu(self.convf2(flo))
1112
+ cor_flo = torch.cat([cor, flo], dim=1)
1113
+ out = F.relu(self.conv(cor_flo))
1114
+ return torch.cat([out, flow], dim=1)
1115
+ else:
1116
+ # the flow input is already an encoded feature (pdim channels), so concatenate it with the correlation features directly
1117
+ cor_flo = torch.cat([cor, flow], dim=1)
1118
+ return F.relu(self.conv(cor_flo))
1119
+ # return torch.cat([out, flow], dim=1)
1120
+
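BasicMotionEncoder compresses the correlation pyramid and the current flow estimate into a fixed-width motion feature; for pdim of 2 or 4 the raw flow is re-concatenated at the end, so the output has exactly dim channels. A small shape check (editor's sketch, using Net's default corr_levels=5 and corr_radius=4):

import torch

corr_channel = 5 * (2 * 4 + 1) ** 2            # corr_levels * (2*corr_radius + 1)**2 = 405
enc = BasicMotionEncoder(corr_channel, dim=128, pdim=2)
flow = torch.zeros(6, 2, 48, 64)               # B*S, 2, H/8, W/8
corr = torch.zeros(6, corr_channel, 48, 64)    # B*S, corr_channel, H/8, W/8
out = enc(flow, corr)
print(out.shape)                               # torch.Size([6, 128, 48, 64])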
1121
+ def conv133_encoder(input_dim, dim, expansion_factor=4):
1122
+ return nn.Sequential(
1123
+ nn.Conv2d(input_dim, dim*expansion_factor, kernel_size=1),
1124
+ nn.GELU(),
1125
+ nn.Conv2d(dim*expansion_factor, dim*expansion_factor, kernel_size=3, padding=1),
1126
+ nn.GELU(),
1127
+ nn.Conv2d(dim*expansion_factor, dim, kernel_size=3, padding=1),
1128
+ )
1129
+
1130
+ class BasicUpdateBlock(nn.Module):
1131
+ def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128):
1132
+ # flowfeat is hdim; ctxfeat is cdim. typically hdim==cdim.
1133
+ super(BasicUpdateBlock, self).__init__()
1134
+ self.encoder = BasicMotionEncoder(corr_channel, dim=cdim)
1135
+ self.compressor = conv1x1(2*cdim+hdim, hdim)
1136
+
1137
+ self.refine = []
1138
+ for i in range(num_blocks):
1139
+ self.refine.append(CNBlock1d(hdim, hdim))
1140
+ self.refine.append(CNBlock2d(hdim, hdim))
1141
+ self.refine = nn.ModuleList(self.refine)
1142
+
1143
+ def forward(self, flowfeat, ctxfeat, corr, flow, S, upsample=True):
1144
+ BS,C,H,W = flowfeat.shape
1145
+ B = BS//S
1146
+
1147
+ # with torch.no_grad():
1148
+ motion_features = self.encoder(flow, corr)
1149
+ flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features], dim=1))
1150
+
1151
+ for blk in self.refine:
1152
+ flowfeat = blk(flowfeat, S)
1153
+ return flowfeat
1154
+
1155
+ class FullUpdateBlock(nn.Module):
1156
+ def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128, pdim=2, use_attn=False):
1157
+ # flowfeat is hdim; ctxfeat is cdim. typically hdim==cdim.
1158
+ super(FullUpdateBlock, self).__init__()
1159
+ self.encoder = BasicMotionEncoder(corr_channel, dim=cdim, pdim=pdim)
1160
+
1161
+ # note we have hdim==cdim
1162
+ # compressor chans:
1163
+ # dim for flowfeat
1164
+ # dim for ctxfeat
1165
+ # dim for motion_features
1166
+ # pdim for flow (if pdim != 2, e.g., if we give sincos(relflow))
1167
+ # 2 for visconf
1168
+
1169
+ if pdim==2:
1170
+ # hdim==cdim
1171
+ # dim for flowfeat
1172
+ # dim for ctxfeat
1173
+ # dim for motion_features
1174
+ # 2 for visconf
1175
+ self.compressor = conv1x1(2*cdim+hdim+2, hdim)
1176
+ else:
1177
+ # we concatenate the flow info again, to not lose it (e.g., from the relu)
1178
+ self.compressor = conv1x1(2*cdim+hdim+2+pdim, hdim)
1179
+
1180
+ self.refine = []
1181
+ for i in range(num_blocks):
1182
+ self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn))
1183
+ self.refine.append(CNBlock2d(hdim, hdim))
1184
+ self.refine = nn.ModuleList(self.refine)
1185
+
1186
+ def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
1187
+ BS,C,H,W = flowfeat.shape
1188
+ B = BS//S
1189
+ motion_features = self.encoder(flow, corr)
1190
+ flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features, visconf], dim=1))
1191
+ for blk in self.refine:
1192
+ flowfeat = blk(flowfeat, S)
1193
+ return flowfeat
1194
+
1195
+ class MixerUpdateBlock(nn.Module):
1196
+ def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128):
1197
+ # flowfeat is hdim; ctxfeat is cdim. typically hdim==cdim.
1198
+ super(MixerUpdateBlock, self).__init__()
1199
+ self.encoder = BasicMotionEncoder(corr_channel, dim=cdim)
1200
+ self.compressor = conv1x1(2*cdim+hdim, hdim)
1201
+
1202
+ self.refine = []
1203
+ for i in range(num_blocks):
1204
+ self.refine.append(CNBlock1d(hdim, hdim, use_mixer=True))
1205
+ self.refine.append(CNBlock2d(hdim, hdim))
1206
+ self.refine = nn.ModuleList(self.refine)
1207
+
1208
+ def forward(self, flowfeat, ctxfeat, corr, flow, S, upsample=True):
1209
+ BS,C,H,W = flowfeat.shape
1210
+ B = BS//S
1211
+
1212
+ # with torch.no_grad():
1213
+ motion_features = self.encoder(flow, corr)
1214
+ flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features], dim=1))
1215
+
1216
+ for ii, blk in enumerate(self.refine):
1217
+ flowfeat = blk(flowfeat, S)
1218
+ return flowfeat
1219
+
1220
+ class FacUpdateBlock(nn.Module):
1221
+ def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128, pdim=84, use_attn=False):
1222
+ super(FacUpdateBlock, self).__init__()
1223
+ self.corr_encoder = conv133_encoder(corr_channel, cdim)
1224
+ # note we have hdim==cdim
1225
+ # compressor chans:
1226
+ # dim for flowfeat
1227
+ # dim for ctxfeat
1228
+ # dim for corr
1229
+ # pdim for flow
1230
+ # 2 for visconf
1231
+ self.compressor = conv1x1(2*cdim+hdim+2+pdim, hdim)
1232
+ self.refine = []
1233
+ for i in range(num_blocks):
1234
+ self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn))
1235
+ self.refine.append(CNBlock2d(hdim, hdim))
1236
+ self.refine = nn.ModuleList(self.refine)
1237
+
1238
+ def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
1239
+ BS,C,H,W = flowfeat.shape
1240
+ B = BS//S
1241
+ corr = self.corr_encoder(corr)
1242
+ flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, corr, visconf, flow], dim=1))
1243
+ for blk in self.refine:
1244
+ flowfeat = blk(flowfeat, S)
1245
+ return flowfeat
1246
+
1247
+ class CleanUpdateBlock(nn.Module):
1248
+ def __init__(self, corr_channel, num_blocks, cdim=128, hdim=256, pdim=84, use_attn=False, use_layer_scale=True):
1249
+ super(CleanUpdateBlock, self).__init__()
1250
+ self.corr_encoder = conv133_encoder(corr_channel, cdim)
1251
+ # compressor chans:
1252
+ # cdim for flowfeat
1253
+ # cdim for ctxfeat
1254
+ # cdim for corrfeat
1255
+ # pdim for flow
1256
+ # 2 for visconf
1257
+ self.compressor = conv1x1(3*cdim+pdim+2, hdim)
1258
+ self.refine = []
1259
+ for i in range(num_blocks):
1260
+ self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn, use_layer_scale=use_layer_scale))
1261
+ self.refine.append(CNBlock2d(hdim, hdim, use_layer_scale=use_layer_scale))
1262
+ self.refine = nn.ModuleList(self.refine)
1263
+ self.final_conv = conv1x1(hdim, cdim)
1264
+
1265
+ def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
1266
+ BS,C,H,W = flowfeat.shape
1267
+ B = BS//S
1268
+ corrfeat = self.corr_encoder(corr)
1269
+ flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, corrfeat, flow, visconf], dim=1))
1270
+ for blk in self.refine:
1271
+ flowfeat = blk(flowfeat, S)
1272
+ flowfeat = self.final_conv(flowfeat)
1273
+ return flowfeat
1274
+
1275
+ class RelUpdateBlock(nn.Module):
1276
+ def __init__(self, corr_channel, num_blocks, cdim=128, hdim=128, pdim=4, use_attn=True, use_mixer=False, use_conv=False, use_convb=False, use_layer_scale=True, no_time=False, no_space=False, no_ctx=False):
1277
+ super(RelUpdateBlock, self).__init__()
1278
+ self.motion_encoder = BasicMotionEncoder(corr_channel, dim=hdim, pdim=pdim) # B,hdim,H,W
1279
+ self.no_ctx = no_ctx
1280
+ if no_ctx:
1281
+ self.compressor = conv1x1(cdim+hdim+2, hdim)
1282
+ else:
1283
+ self.compressor = conv1x1(2*cdim+hdim+2, hdim)
1284
+ self.refine = []
1285
+ for i in range(num_blocks):
1286
+ if not no_time:
1287
+ self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn, use_mixer=use_mixer, use_conv=use_conv, use_convb=use_convb, use_layer_scale=use_layer_scale))
1288
+ if not no_space:
1289
+ self.refine.append(CNBlock2d(hdim, hdim, use_layer_scale=use_layer_scale))
1290
+ self.refine = nn.ModuleList(self.refine)
1291
+ self.final_conv = conv1x1(hdim, cdim)
1292
+
1293
+ def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
1294
+ BS,C,H,W = flowfeat.shape
1295
+ B = BS//S
1296
+ motion_features = self.motion_encoder(flow, corr)
1297
+ if self.no_ctx:
1298
+ flowfeat = self.compressor(torch.cat([flowfeat, motion_features, visconf], dim=1))
1299
+ else:
1300
+ flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features, visconf], dim=1))
1301
+ for blk in self.refine:
1302
+ flowfeat = blk(flowfeat, S)
1303
+ flowfeat = self.final_conv(flowfeat)
1304
+ return flowfeat
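RelUpdateBlock ties the pieces above together: it encodes motion from the correlation volume and the flow, concatenates flow features, (optionally) context features, and visibility/confidence maps, compresses them to hdim, and refines the result with alternating temporal and spatial blocks. A usage sketch (editor's illustration, not part of the commit), assuming CNBlock1d and CNBlock2d behave as in the rest of this repo:

import torch

corr_channel = 5 * (2 * 4 + 1) ** 2
blk = RelUpdateBlock(corr_channel, num_blocks=3, cdim=128, hdim=128, pdim=4)
S, H, W = 16, 48, 64                       # S frames at 1/8 resolution, batch size 1
flowfeat = torch.zeros(S, 128, H, W)
ctxfeat = torch.zeros(S, 128, H, W)
visconf = torch.zeros(S, 2, H, W)
corr = torch.zeros(S, corr_channel, H, W)
flow = torch.zeros(S, 4, H, W)             # pdim=4, e.g., a sin/cos encoding of relative flow
out = blk(flowfeat, ctxfeat, visconf, corr, flow, S=S)
print(out.shape)                           # torch.Size([16, 128, 48, 64])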
utils/basic.py ADDED
@@ -0,0 +1,144 @@
1
+ import torch
2
+ import numpy as np
3
+ import os
4
+ EPS = 1e-6
5
+
6
+ def sub2ind(height, width, y, x):
7
+ return y*width + x
8
+
9
+ def ind2sub(height, width, ind):
10
+ y = ind // width
11
+ x = ind % width
12
+ return y, x
13
+
14
+ def get_lr_str(lr):
15
+ lrn = "%.1e" % lr # e.g., 5.0e-04
16
+ lrn = lrn[0] + lrn[3:5] + lrn[-1] # e.g., 5e-4
17
+ return lrn
18
+
19
+ def strnum(x):
20
+ s = '%g' % x
21
+ if '.' in s:
22
+ if x < 1.0:
23
+ s = s[s.index('.'):]
24
+ s = s[:min(len(s),4)]
25
+ return s
26
+
27
+ def assert_same_shape(t1, t2):
28
+ for (x, y) in zip(list(t1.shape), list(t2.shape)):
29
+ assert(x==y)
30
+
31
+ def mkdir(path):
32
+ if not os.path.exists(path):
33
+ os.makedirs(path)
34
+
35
+ def print_stats(name, tensor):
36
+ shape = tensor.shape
37
+ tensor = tensor.detach().cpu().numpy()
38
+ print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape)
39
+
40
+ def normalize_single(d):
41
+ # d is a whatever shape torch tensor
42
+ dmin = torch.min(d)
43
+ dmax = torch.max(d)
44
+ d = (d-dmin)/(EPS+(dmax-dmin))
45
+ return d
46
+
47
+ def normalize(d):
48
+ # d is B x whatever. normalize within each element of the batch
49
+ out = torch.zeros(d.size(), dtype=d.dtype, device=d.device)
50
+ B = list(d.size())[0]
51
+ for b in list(range(B)):
52
+ out[b] = normalize_single(d[b])
53
+ return out
54
+
55
+ def meshgrid2d(B, Y, X, stack=False, norm=False, device='cuda', on_chans=False):
56
+ # returns a meshgrid sized B x Y x X
57
+
58
+ grid_y = torch.linspace(0.0, Y-1, Y, device=torch.device(device))
59
+ grid_y = torch.reshape(grid_y, [1, Y, 1])
60
+ grid_y = grid_y.repeat(B, 1, X)
61
+
62
+ grid_x = torch.linspace(0.0, X-1, X, device=torch.device(device))
63
+ grid_x = torch.reshape(grid_x, [1, 1, X])
64
+ grid_x = grid_x.repeat(B, Y, 1)
65
+
66
+ if norm:
67
+ grid_y, grid_x = normalize_grid2d(
68
+ grid_y, grid_x, Y, X)
69
+
70
+ if stack:
71
+ # note we stack in xy order
72
+ # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
73
+ if on_chans:
74
+ grid = torch.stack([grid_x, grid_y], dim=1)
75
+ else:
76
+ grid = torch.stack([grid_x, grid_y], dim=-1)
77
+ return grid
78
+ else:
79
+ return grid_y, grid_x
80
+
81
+ def gridcloud2d(B, Y, X, norm=False, device='cuda'):
82
+ # we want to sample for each location in the grid
83
+ grid_y, grid_x = meshgrid2d(B, Y, X, norm=norm, device=device)
84
+ x = torch.reshape(grid_x, [B, -1])
85
+ y = torch.reshape(grid_y, [B, -1])
86
+ # these are B x N
87
+ xy = torch.stack([x, y], dim=2)
88
+ # this is B x N x 2
89
+ return xy
90
+
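A quick shape example for the grid helpers (editor's sketch); note the default device is 'cuda', so pass device='cpu' on machines without a GPU:

import torch

grid = meshgrid2d(B=1, Y=3, X=4, stack=True, device='cpu')   # 1,3,4,2 with (x, y) stacked last
xy = gridcloud2d(B=1, Y=3, X=4, device='cpu')                 # 1,12,2 flattened pixel coordinates
print(grid.shape, xy.shape)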
91
+ def reduce_masked_mean(x, mask, dim=None, keepdim=False, broadcast=False):
92
+ # x and mask should have the same shape; broadcasting is only allowed when broadcast=True, but matching shapes are safer
93
+ # returns shape-1
94
+ # dim can be an int or a list of dims
95
+ if not broadcast:
96
+ for (a,b) in zip(x.size(), mask.size()):
97
+ if not a==b:
98
+ print('some shape mismatch:', x.shape, mask.shape)
99
+ assert(a==b) # some shape mismatch!
100
+ # assert(x.size() == mask.size())
101
+ prod = x*mask
102
+ if dim is None:
103
+ numer = torch.sum(prod)
104
+ denom = EPS+torch.sum(mask)
105
+ else:
106
+ numer = torch.sum(prod, dim=dim, keepdim=keepdim)
107
+ denom = EPS+torch.sum(mask, dim=dim, keepdim=keepdim)
108
+ mean = numer/denom
109
+ return mean
110
+
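reduce_masked_mean averages x only where mask is on; with dim=None it reduces everything, otherwise it reduces the given dim(s). A tiny example (editor's sketch):

import torch

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
print(reduce_masked_mean(x, mask))           # tensor(1.5000): only the masked-in entries count
print(reduce_masked_mean(x, mask, dim=1))    # tensor([1.5000]): per-row masked mean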
111
+ def reduce_masked_median(x, mask, keep_batch=False):
112
+ # x and mask are the same shape
113
+ assert(x.size() == mask.size())
114
+ device = x.device
115
+
116
+ B = list(x.shape)[0]
117
+ x = x.detach().cpu().numpy()
118
+ mask = mask.detach().cpu().numpy()
119
+
120
+ if keep_batch:
121
+ x = np.reshape(x, [B, -1])
122
+ mask = np.reshape(mask, [B, -1])
123
+ meds = np.zeros([B], np.float32)
124
+ for b in list(range(B)):
125
+ xb = x[b]
126
+ mb = mask[b]
127
+ if np.sum(mb) > 0:
128
+ xb = xb[mb > 0]
129
+ meds[b] = np.median(xb)
130
+ else:
131
+ meds[b] = np.nan
132
+ meds = torch.from_numpy(meds).to(device)
133
+ return meds.float()
134
+ else:
135
+ x = np.reshape(x, [-1])
136
+ mask = np.reshape(mask, [-1])
137
+ if np.sum(mask) > 0:
138
+ x = x[mask > 0]
139
+ med = np.median(x)
140
+ else:
141
+ med = np.nan
142
+ med = np.array([med], np.float32)
143
+ med = torch.from_numpy(med).to(device)
144
+ return med.float()
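utils/improc.py below calls utils.basic.pack_seqdim and utils.basic.unpack_seqdim, which do not appear in this 144-line file. A minimal sketch of the conventional behavior, added by the editor as an assumption (fold the sequence dim into the batch dim and back), not as the author's implementation:

def pack_seqdim(tensor, B):
    # fold the sequence dim into the batch dim: B,S,... -> B*S,...
    shapelist = list(tensor.shape)
    S = shapelist[1]
    return tensor.reshape([B * S] + shapelist[2:])

def unpack_seqdim(tensor, B):
    # unfold the batch dim back into B,S,...: B*S,... -> B,S,...
    shapelist = list(tensor.shape)
    S = shapelist[0] // B
    return tensor.reshape([B, S] + shapelist[1:])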
utils/data.py ADDED
@@ -0,0 +1,96 @@
1
+ import torch
2
+ import dataclasses
3
+ import torch.nn.functional as F
4
+ from dataclasses import dataclass
5
+ from typing import Any, Optional, Dict
6
+
7
+
8
+ @dataclass(eq=False)
9
+ class VideoData:
10
+ """
11
+ Dataclass for storing video tracks data.
12
+ """
13
+
14
+ video: torch.Tensor # B,S,C,H,W
15
+ trajs: torch.Tensor # B,S,N,2
16
+ visibs: torch.Tensor # B,S,N
17
+ valids: Optional[torch.Tensor] = None # B,S,N
18
+ seq_name: Optional[str] = None
19
+ dname: Optional[str] = None
20
+ aug_video: Optional[torch.Tensor] = None
21
+
22
+
23
+ def collate_fn(batch):
24
+ """
25
+ Collate function for video tracks data.
26
+ """
27
+ video = torch.stack([b.video for b in batch], dim=0)
28
+ trajs = torch.stack([b.trajs for b in batch], dim=0)
29
+ visibs = torch.stack([b.visibs for b in batch], dim=0)
30
+ seq_name = [b.seq_name for b in batch]
31
+ dname = [b.dname for b in batch]
32
+
33
+ return VideoData(
34
+ video=video,
35
+ trajs=trajs,
36
+ visibs=visibs,
37
+ seq_name=seq_name,
38
+ dname=dname,
39
+ )
40
+
41
+
42
+ def collate_fn_train(batch):
43
+ """
44
+ Collate function for video tracks data during training.
45
+ """
46
+ gotit = [gotit for _, gotit in batch]
47
+ video = torch.stack([b.video for b, _ in batch], dim=0)
48
+ trajs = torch.stack([b.trajs for b, _ in batch], dim=0)
49
+ visibs = torch.stack([b.visibs for b, _ in batch], dim=0)
50
+ valids = torch.stack([b.valids for b, _ in batch], dim=0)
51
+ seq_name = [b.seq_name for b, _ in batch]
52
+ dname = [b.dname for b, _ in batch]
53
+
54
+ return (
55
+ VideoData(
56
+ video=video,
57
+ trajs=trajs,
58
+ visibs=visibs,
59
+ valids=valids,
60
+ seq_name=seq_name,
61
+ dname=dname,
62
+ ),
63
+ gotit,
64
+ )
65
+
66
+
67
+ def try_to_cuda(t: Any) -> Any:
68
+ """
69
+ Try to move the input variable `t` to a cuda device.
70
+
71
+ Args:
72
+ t: Input.
73
+
74
+ Returns:
75
+ t_cuda: `t` moved to a cuda device, if supported.
76
+ """
77
+ try:
78
+ t = t.float().cuda()
79
+ except AttributeError:
80
+ pass
81
+ return t
82
+
83
+
84
+ def dataclass_to_cuda_(obj):
85
+ """
86
+ Move all contents of a dataclass to cuda inplace if supported.
87
+
88
+ Args:
89
+ batch: Input dataclass.
90
+
91
+ Returns:
92
+ batch_cuda: `batch` moved to a cuda device, if supported.
93
+ """
94
+ for f in dataclasses.fields(obj):
95
+ setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
96
+ return obj
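A usage sketch for the data utilities (editor's illustration, with hypothetical shapes): each dataset sample carries per-frame tensors, collate_fn stacks a list of samples along a new batch dim, and dataclass_to_cuda_ tries .float().cuda() on every field, leaving non-tensor fields (strings, lists, None) untouched.

import torch

sample = VideoData(
    video=torch.zeros(16, 3, 256, 256),   # S, C, H, W
    trajs=torch.zeros(16, 128, 2),        # S, N, 2
    visibs=torch.ones(16, 128),           # S, N
    seq_name='demo',
    dname='synthetic',
)
batch = collate_fn([sample, sample])      # video becomes 2,16,3,256,256, etc.
batch = dataclass_to_cuda_(batch)         # requires a CUDA device; tensor fields move, strings pass through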
utils/improc.py ADDED
@@ -0,0 +1,1103 @@
1
+ import torch
2
+ import numpy as np
3
+ import utils.basic
4
+ import utils.py
5
+ from sklearn.decomposition import PCA
6
+ from matplotlib import cm
7
+ import matplotlib.pyplot as plt
8
+ import cv2
9
+ import torch.nn.functional as F
10
+ EPS = 1e-6
11
+
12
+ from skimage.color import (
13
+ rgb2lab, rgb2yuv, rgb2ycbcr, lab2rgb, yuv2rgb, ycbcr2rgb,
14
+ rgb2hsv, hsv2rgb, rgb2xyz, xyz2rgb, rgb2hed, hed2rgb)
15
+
16
+ def _convert(input_, type_):
17
+ return {
18
+ 'float': input_.float(),
19
+ 'double': input_.double(),
20
+ }.get(type_, input_)
21
+
22
+ def _generic_transform_sk_3d(transform, in_type='', out_type=''):
23
+ def apply_transform_individual(input_):
24
+ device = input_.device
25
+ input_ = input_.cpu()
26
+ input_ = _convert(input_, in_type)
27
+
28
+ input_ = input_.permute(1, 2, 0).detach().numpy()
29
+ transformed = transform(input_)
30
+ output = torch.from_numpy(transformed).float().permute(2, 0, 1)
31
+ output = _convert(output, out_type)
32
+ return output.to(device)
33
+
34
+ def apply_transform(input_):
35
+ to_stack = []
36
+ for image in input_:
37
+ to_stack.append(apply_transform_individual(image))
38
+ return torch.stack(to_stack)
39
+ return apply_transform
40
+
41
+ hsv_to_rgb = _generic_transform_sk_3d(hsv2rgb)
42
+
43
+ def flow2color(flow, clip=0.0):
44
+ B, C, H, W = list(flow.size())
45
+ assert(C==2)
46
+ flow = flow[0:1].detach()
47
+ if clip==0:
48
+ clip = torch.max(torch.abs(flow)).item()
49
+ flow = torch.clamp(flow, -clip, clip)/clip
50
+ radius = torch.sqrt(torch.sum(flow**2, dim=1, keepdim=True)) # B,1,H,W
51
+ radius_clipped = torch.clamp(radius, 0.0, 1.0)
52
+ angle = torch.atan2(-flow[:, 1:2], -flow[:, 0:1]) / np.pi # B,1,H,W
53
+ hue = torch.clamp((angle + 1.0) / 2.0, 0.0, 1.0)
54
+ saturation = torch.ones_like(hue) * 0.75
55
+ value = radius_clipped
56
+ hsv = torch.cat([hue, saturation, value], dim=1) # B,3,H,W
57
+ flow = hsv_to_rgb(hsv)
58
+ flow = (flow*255.0).type(torch.ByteTensor)
59
+ return flow
60
+
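flow2color maps a 2-channel flow field to an RGB image: hue encodes direction, value encodes (clipped) magnitude, and only the first batch element is visualized. A small example (editor's sketch):

import torch

flow = torch.randn(1, 2, 64, 64) * 5.0   # B,2,H,W flow in pixels
vis = flow2color(flow)                   # 1,3,H,W uint8
print(vis.shape, vis.dtype)              # torch.Size([1, 3, 64, 64]) torch.uint8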
61
+ COLORMAP_FILE = "./utils/bremm.png"
62
+ class ColorMap2d:
63
+ def __init__(self, filename=None):
64
+ self._colormap_file = filename or COLORMAP_FILE
65
+ self._img = (plt.imread(self._colormap_file)*255).astype(np.uint8)
66
+
67
+ self._height = self._img.shape[0]
68
+ self._width = self._img.shape[1]
69
+
70
+ def __call__(self, X):
71
+ assert len(X.shape) == 2
72
+ output = np.zeros((X.shape[0], 3), dtype=np.uint8)
73
+ for i in range(X.shape[0]):
74
+ x, y = X[i, :]
75
+ xp = int((self._width-1) * x)
76
+ yp = int((self._height-1) * y)
77
+ xp = np.clip(xp, 0, self._width-1)
78
+ yp = np.clip(yp, 0, self._height-1)
79
+ output[i, :] = self._img[yp, xp]
80
+ return output
81
+
82
+ def get_2d_colors(xys, H, W):
83
+ N,D = xys.shape
84
+ assert(D==2)
85
+ bremm = ColorMap2d()
86
+ xys[:,0] /= float(W-1)
87
+ xys[:,1] /= float(H-1)
88
+ colors = bremm(xys)
89
+ # print('colors', colors)
90
+ # colors = (colors[0]*255).astype(np.uint8)
91
+ # colors = (int(colors[0]),int(colors[1]),int(colors[2]))
92
+ return colors
93
+
94
+
95
+ def get_n_colors(N, sequential=False):
96
+ label_colors = []
97
+ for ii in range(N):
98
+ if sequential:
99
+ rgb = cm.winter(ii/(N-1))
100
+ rgb = (np.array(rgb) * 255).astype(np.uint8)[:3]
101
+ else:
102
+ rgb = np.zeros(3)
103
+ while np.sum(rgb) < 128: # ensure min brightness
104
+ rgb = np.random.randint(0,256,3)
105
+ label_colors.append(rgb)
106
+ return label_colors
107
+
108
+ def pca_embed(emb, keep, valid=None):
109
+ # helper function for reduce_emb
110
+ # emb is B,C,H,W
111
+ # keep is the number of principal components to keep
112
+ emb = emb + EPS
113
+ emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy() #this is B x H x W x C
114
+
115
+ if valid is not None:
116
+ valid = valid.cpu().detach().numpy().reshape(-1) > 0
117
+
118
+ emb_reduced = list()
119
+
120
+ B, H, W, C = np.shape(emb)
121
+ for img in emb:
122
+ if np.isnan(img).any():
123
+ emb_reduced.append(np.zeros([H, W, keep]))
124
+ continue
125
+
126
+ pixels_kd = np.reshape(img, (H*W, C))
127
+
128
+ if valid is not None:
129
+ pixels_kd_pca = pixels_kd[valid]
130
+ else:
131
+ pixels_kd_pca = pixels_kd
132
+
133
+ P = PCA(keep)
134
+ P.fit(pixels_kd_pca)
135
+
136
+ if valid is not None:
137
+ pixels3d = P.transform(pixels_kd) * valid[:, None]
138
+ else:
139
+ pixels3d = P.transform(pixels_kd)
140
+
141
+ out_img = np.reshape(pixels3d, [H,W,keep]).astype(np.float32)
142
+ if np.isnan(out_img).any():
143
+ emb_reduced.append(np.zeros([H, W, keep]))
144
+ continue
145
+
146
+ emb_reduced.append(out_img)
147
+
148
+ emb_reduced = np.stack(emb_reduced, axis=0).astype(np.float32)
149
+
150
+ return torch.from_numpy(emb_reduced).permute(0, 3, 1, 2)
151
+
152
+ def pca_embed_together(emb, keep):
153
+ # emb is B,C,H,W
154
+ # keep is the number of principal components to keep
155
+ emb = emb + EPS
156
+ emb = emb.permute(0, 2, 3, 1).cpu().detach().float().numpy() #this is B x H x W x C
157
+
158
+ B, H, W, C = np.shape(emb)
159
+ if np.isnan(emb).any():
160
+ return torch.zeros(B, keep, H, W)
161
+
162
+ pixelskd = np.reshape(emb, (B*H*W, C))
163
+ P = PCA(keep)
164
+ P.fit(pixelskd)
165
+ pixels3d = P.transform(pixelskd)
166
+ out_img = np.reshape(pixels3d, [B,H,W,keep]).astype(np.float32)
167
+
168
+ if np.isnan(out_img).any():
169
+ return torch.zeros(B, keep, H, W)
170
+
171
+ return torch.from_numpy(out_img).permute(0, 3, 1, 2)
172
+
173
+ def reduce_emb(emb, valid=None, inbound=None, together=False):
174
+ S, C, H, W = list(emb.size())
175
+ keep = 4
176
+
177
+ if together:
178
+ reduced_emb = pca_embed_together(emb, keep)
179
+ else:
180
+ reduced_emb = pca_embed(emb, keep, valid) #not im
181
+
182
+ reduced_emb = reduced_emb[:,1:]
183
+ reduced_emb = utils.basic.normalize(reduced_emb) - 0.5
184
+ if inbound is not None:
185
+ emb_inbound = emb*inbound
186
+ else:
187
+ emb_inbound = None
188
+
189
+ return reduced_emb, emb_inbound
190
+
191
+ def get_feat_pca(feat, valid=None):
192
+ B, C, D, W = list(feat.size())
193
+ pca, _ = reduce_emb(feat, valid=valid,inbound=None, together=True)
194
+ return pca
195
+
196
+ def gif_and_tile(ims, just_gif=False):
197
+ S = len(ims)
198
+ # each im is B x C x H x W
199
+ # i want a gif in the left, and the tiled frames on the right
200
+ # for the gif tool, this means making a B x S x H x W tensor
201
+ # where the leftmost part is sequential and the rest is tiled
202
+ gif = torch.stack(ims, dim=1)
203
+ if just_gif:
204
+ return gif
205
+ til = torch.cat(ims, dim=2)
206
+ til = til.unsqueeze(dim=1).repeat(1, S, 1, 1, 1)
207
+ im = torch.cat([gif, til], dim=3)
208
+ return im
209
+
210
+ def preprocess_color(x):
211
+ if isinstance(x, np.ndarray):
212
+ return x.astype(np.float32) * 1./255 - 0.5
213
+ else:
214
+ return x.float() * 1./255 - 0.5
215
+
216
+ def back2color(i, blacken_zeros=False):
217
+ if blacken_zeros:
218
+ const = torch.tensor([-0.5])
219
+ i = torch.where(i==0.0, const.cuda() if i.is_cuda else const, i)
220
+ return back2color(i)
221
+ else:
222
+ return ((i+0.5)*255).type(torch.ByteTensor)
223
+
224
+ def draw_frame_id_on_vis(vis, frame_id, scale=0.5, left=5, top=20, shadow=True):
225
+
226
+ rgb = vis.detach().cpu().numpy()[0]
227
+ rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
228
+ rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
229
+ color = (255, 255, 255)
230
+ # print('putting frame id', frame_id)
231
+
232
+ frame_str = utils.basic.strnum(frame_id)
233
+
234
+ text_color_bg = (0,0,0)
235
+ font = cv2.FONT_HERSHEY_SIMPLEX
236
+ text_size, _ = cv2.getTextSize(frame_str, font, scale, 1)
237
+ text_w, text_h = text_size
238
+ if shadow:
239
+ cv2.rectangle(rgb, (left, top-text_h), (left + text_w, top+1), text_color_bg, -1)
240
+
241
+ cv2.putText(
242
+ rgb,
243
+ frame_str,
244
+ (left, top), # from left, from top
245
+ font,
246
+ scale, # font scale (float)
247
+ color,
248
+ 1) # font thickness (int)
249
+ rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB)
250
+ vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
251
+ return vis
252
+
253
+ def draw_frame_str_on_vis(vis, frame_str, scale=0.5, left=5, top=40, shadow=True):
254
+
255
+ rgb = vis.detach().cpu().numpy()[0]
256
+ rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
257
+ rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
258
+ color = (255, 255, 255)
259
+
260
+ text_color_bg = (0,0,0)
261
+ font = cv2.FONT_HERSHEY_SIMPLEX
262
+ text_size, _ = cv2.getTextSize(frame_str, font, scale, 1)
263
+ text_w, text_h = text_size
264
+ if shadow:
265
+ cv2.rectangle(rgb, (left, top-text_h), (left + text_w, top+1), text_color_bg, -1)
266
+
267
+ cv2.putText(
268
+ rgb,
269
+ frame_str,
270
+ (left, top), # from left, from top
271
+ font,
272
+ scale, # font scale (float)
273
+ color,
274
+ 1) # font thickness (int)
275
+ rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB)
276
+ vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
277
+ return vis
278
+
279
+ class Summ_writer(object):
280
+ def __init__(self, writer, global_step, log_freq=10, fps=8, scalar_freq=100, just_gif=False):
281
+ self.writer = writer
282
+ self.global_step = global_step
283
+ self.log_freq = log_freq
284
+ self.scalar_freq = scalar_freq
285
+ self.fps = fps
286
+ self.just_gif = just_gif
287
+ self.maxwidth = 10000
288
+ self.save_this = (self.global_step % self.log_freq == 0)
289
+ self.scalar_freq = max(scalar_freq,1)
290
+ self.save_scalar = (self.global_step % self.scalar_freq == 0)
291
+ if self.save_this:
292
+ self.save_scalar = True
293
+
294
+ def summ_gif(self, name, tensor, blacken_zeros=False):
295
+ # tensor should be in B x S x C x H x W
296
+
297
+ assert tensor.dtype in {torch.uint8,torch.float32}
298
+ shape = list(tensor.shape)
299
+
300
+ if tensor.dtype == torch.float32:
301
+ tensor = back2color(tensor, blacken_zeros=blacken_zeros)
302
+
303
+ video_to_write = tensor[0:1]
304
+
305
+ S = video_to_write.shape[1]
306
+ if S==1:
307
+ # video_to_write is 1 x 1 x C x H x W
308
+ self.writer.add_image(name, video_to_write[0,0], global_step=self.global_step)
309
+ else:
310
+ self.writer.add_video(name, video_to_write, fps=self.fps, global_step=self.global_step)
311
+
312
+ return video_to_write
313
+
314
+ def summ_rgbs(self, name, ims, frame_ids=None, frame_strs=None, blacken_zeros=False, only_return=False):
315
+ if self.save_this:
316
+
317
+ ims = gif_and_tile(ims, just_gif=self.just_gif)
318
+ vis = ims
319
+
320
+ assert vis.dtype in {torch.uint8,torch.float32}
321
+
322
+ if vis.dtype == torch.float32:
323
+ vis = back2color(vis, blacken_zeros)
324
+
325
+ B, S, C, H, W = list(vis.shape)
326
+
327
+ if frame_ids is not None:
328
+ assert(len(frame_ids)==S)
329
+ for s in range(S):
330
+ vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s])
331
+
332
+ if frame_strs is not None:
333
+ assert(len(frame_strs)==S)
334
+ for s in range(S):
335
+ vis[:,s] = draw_frame_str_on_vis(vis[:,s], frame_strs[s])
336
+
337
+ if int(W) > self.maxwidth:
338
+ vis = vis[:,:,:,:self.maxwidth]
339
+
340
+ if only_return:
341
+ return vis
342
+ else:
343
+ return self.summ_gif(name, vis, blacken_zeros)
344
+
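Summ_writer wraps a tensorboard SummaryWriter and only materializes visualizations every log_freq steps (save_this). A construction-and-logging sketch from the editor, assuming tensorboard is installed and that images are floats in the [-0.5, 0.5] range produced by preprocess_color:

import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('logs/demo')
sw = Summ_writer(writer, global_step=0, log_freq=1, fps=8)
frames = [torch.zeros(1, 3, 64, 64) for _ in range(8)]     # list of B,3,H,W floats in [-0.5, 0.5]
sw.summ_rgbs('inputs/rgbs', frames, frame_ids=list(range(8)))
writer.close()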
345
+ def summ_rgb(self, name, ims, blacken_zeros=False, frame_id=None, frame_str=None, only_return=False, halfres=False, shadow=True):
346
+ if self.save_this:
347
+ assert ims.dtype in {torch.uint8,torch.float32}
348
+
349
+ if ims.dtype == torch.float32:
350
+ ims = back2color(ims, blacken_zeros)
351
+
352
+ #ims is B x C x H x W
353
+ vis = ims[0:1] # just the first one
354
+ B, C, H, W = list(vis.shape)
355
+
356
+ if halfres:
357
+ vis = F.interpolate(vis, scale_factor=0.5)
358
+
359
+ if frame_id is not None:
360
+ vis = draw_frame_id_on_vis(vis, frame_id, shadow=shadow)
361
+
362
+ if frame_str is not None:
363
+ vis = draw_frame_str_on_vis(vis, frame_str, shadow=shadow)
364
+
365
+ if int(W) > self.maxwidth:
366
+ vis = vis[:,:,:,:self.maxwidth]
367
+
368
+ if only_return:
369
+ return vis
370
+ else:
371
+ return self.summ_gif(name, vis.unsqueeze(1), blacken_zeros)
372
+
373
+ def flow2color(self, flow, clip=0.0):
374
+ B, C, H, W = list(flow.size())
375
+ assert(C==2)
376
+ flow = flow[0:1].detach()
377
+
378
+ if False:
379
+ flow = flow[0].detach().cpu().permute(1,2,0).numpy() # H,W,2
380
+ if clip > 0:
381
+ clip_flow = clip
382
+ else:
383
+ clip_flow = None
384
+ im = utils.py.flow_to_image(flow, clip_flow=clip_flow, convert_to_bgr=True)
385
+ # im = utils.py.flow_to_image(flow, convert_to_bgr=True)
386
+ im = torch.from_numpy(im).permute(2,0,1).unsqueeze(0).byte() # 1,3,H,W
387
+ im = torch.flip(im, dims=[1]).clone() # BGR
388
+
389
+ # # i prefer black bkg
390
+ # white_pixels = (im == 255).all(dim=1, keepdim=True)
391
+ # im[white_pixels.expand(-1, 3, -1, -1)] = 0
392
+
393
+ return im
394
+
395
+ # flow_abs = torch.abs(flow)
396
+ # flow_mean = flow_abs.mean(dim=[1,2,3])
397
+ # flow_std = flow_abs.std(dim=[1,2,3])
398
+ if clip==0:
399
+ clip = torch.max(torch.abs(flow)).item()
400
+
401
+ # if clip:
402
+ flow = torch.clamp(flow, -clip, clip)/clip
403
+ # else:
404
+ # # # Apply some kind of normalization. Divide by the perceived maximum (mean + std*2)
405
+ # # flow_max = flow_mean + flow_std*2 + 1e-10
406
+ # # for b in range(B):
407
+ # # flow[b] = flow[b].clamp(-flow_max[b].item(), flow_max[b].item()) / flow_max[b].clamp(min=1)
408
+
409
+ # flow_max = torch.max(flow_abs[b])
410
+ # for b in range(B):
411
+ # flow[b] = flow[b].clamp(-flow_max.item(), flow_max.item()) / flow_max[b].clamp(min=1)
412
+
413
+
414
+ radius = torch.sqrt(torch.sum(flow**2, dim=1, keepdim=True)) #B x 1 x H x W
415
+ radius_clipped = torch.clamp(radius, 0.0, 1.0)
416
+
417
+ angle = torch.atan2(-flow[:, 1:2], -flow[:, 0:1]) / np.pi # B x 1 x H x W
418
+
419
+ hue = torch.clamp((angle + 1.0) / 2.0, 0.0, 1.0)
420
+ # hue = torch.mod(angle / (2 * np.pi) + 1.0, 1.0)
421
+
422
+ saturation = torch.ones_like(hue) * 0.75
423
+ value = radius_clipped
424
+ hsv = torch.cat([hue, saturation, value], dim=1) #B x 3 x H x W
425
+
426
+ #flow = tf.image.hsv_to_rgb(hsv)
427
+ flow = hsv_to_rgb(hsv)
428
+ flow = (flow*255.0).type(torch.ByteTensor)
429
+ # flow = torch.flip(flow, dims=[1]).clone() # BGR
430
+ return flow
431
+
432
+ def summ_flow(self, name, im, clip=0.0, only_return=False, frame_id=None, frame_str=None, shadow=True):
433
+ # flow is B x C x D x W
434
+ if self.save_this:
435
+ return self.summ_rgb(name, self.flow2color(im, clip=clip), only_return=only_return, frame_id=frame_id, frame_str=frame_str, shadow=shadow)
436
+ else:
437
+ return None
438
+
439
+ def summ_oneds(self, name, ims, frame_ids=None, frame_strs=None, bev=False, fro=False, logvis=False, reduce_max=False, max_val=0.0, norm=True, only_return=False, do_colorize=False):
440
+ if self.save_this:
441
+ if bev:
442
+ B, C, H, _, W = list(ims[0].shape)
443
+ if reduce_max:
444
+ ims = [torch.max(im, dim=3)[0] for im in ims]
445
+ else:
446
+ ims = [torch.mean(im, dim=3) for im in ims]
447
+ elif fro:
448
+ B, C, _, H, W = list(ims[0].shape)
449
+ if reduce_max:
450
+ ims = [torch.max(im, dim=2)[0] for im in ims]
451
+ else:
452
+ ims = [torch.mean(im, dim=2) for im in ims]
453
+
454
+
455
+ if len(ims) != 1: # sequence
456
+ im = gif_and_tile(ims, just_gif=self.just_gif)
457
+ else:
458
+ im = torch.stack(ims, dim=1) # single frame
459
+
460
+ B, S, C, H, W = list(im.shape)
461
+
462
+ if logvis and max_val:
463
+ max_val = np.log(max_val)
464
+ im = torch.log(torch.clamp(im, 0)+1.0)
465
+ im = torch.clamp(im, 0, max_val)
466
+ im = im/max_val
467
+ norm = False
468
+ elif max_val:
469
+ im = torch.clamp(im, 0, max_val)
470
+ im = im/max_val
471
+ norm = False
472
+
473
+ if norm:
474
+ # normalize before oned2inferno,
475
+ # so that the ranges are similar within B across S
476
+ im = utils.basic.normalize(im)
477
+
478
+ im = im.view(B*S, C, H, W)
479
+ vis = oned2inferno(im, norm=norm, do_colorize=do_colorize)
480
+ vis = vis.view(B, S, 3, H, W)
481
+
482
+ if frame_ids is not None:
483
+ assert(len(frame_ids)==S)
484
+ for s in range(S):
485
+ vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s])
486
+
487
+ if frame_strs is not None:
488
+ assert(len(frame_strs)==S)
489
+ for s in range(S):
490
+ vis[:,s] = draw_frame_str_on_vis(vis[:,s], frame_strs[s])
491
+
492
+ if W > self.maxwidth:
493
+ vis = vis[...,:self.maxwidth]
494
+
495
+ if only_return:
496
+ return vis
497
+ else:
498
+ self.summ_gif(name, vis)
499
+
500
+ def summ_oned(self, name, im, bev=False, fro=False, logvis=False, max_val=0, max_along_y=False, norm=True, frame_id=None, frame_str=None, only_return=False, shadow=True):
501
+ if self.save_this:
502
+
503
+ if bev:
504
+ B, C, H, _, W = list(im.shape)
505
+ if max_along_y:
506
+ im = torch.max(im, dim=3)[0]
507
+ else:
508
+ im = torch.mean(im, dim=3)
509
+ elif fro:
510
+ B, C, _, H, W = list(im.shape)
511
+ if max_along_y:
512
+ im = torch.max(im, dim=2)[0]
513
+ else:
514
+ im = torch.mean(im, dim=2)
515
+ else:
516
+ B, C, H, W = list(im.shape)
517
+
518
+ im = im[0:1] # just the first one
519
+ assert(C==1)
520
+
521
+ if logvis and max_val:
522
+ max_val = np.log(max_val)
523
+ im = torch.log(im)
524
+ im = torch.clamp(im, 0, max_val)
525
+ im = im/max_val
526
+ norm = False
527
+ elif max_val:
528
+ im = torch.clamp(im, 0, max_val)/max_val
529
+ norm = False
530
+
531
+ vis = oned2inferno(im, norm=norm)
532
+ if W > self.maxwidth:
533
+ vis = vis[...,:self.maxwidth]
534
+ return self.summ_rgb(name, vis, blacken_zeros=False, frame_id=frame_id, frame_str=frame_str, only_return=only_return, shadow=shadow)
535
+
536
+
537
+ def summ_feats(self, name, feats, valids=None, pca=True, fro=False, only_return=False, frame_ids=None, frame_strs=None):
538
+ if self.save_this:
539
+ if valids is not None:
540
+ valids = torch.stack(valids, dim=1)
541
+
542
+ feats = torch.stack(feats, dim=1)
543
+ # feats leads with B x S x C
544
+
545
+ if feats.ndim==6:
546
+
547
+ # feats is B x S x C x D x H x W
548
+ if fro:
549
+ reduce_dim = 3
550
+ else:
551
+ reduce_dim = 4
552
+
553
+ if valids is None:
554
+ feats = torch.mean(feats, dim=reduce_dim)
555
+ else:
556
+ valids = valids.repeat(1, 1, feats.size()[2], 1, 1, 1)
557
+ feats = utils.basic.reduce_masked_mean(feats, valids, dim=reduce_dim)
558
+
559
+ B, S, C, D, W = list(feats.size())
560
+
561
+ if not pca:
562
+ # feats leads with B x S x C
563
+ feats = torch.mean(torch.abs(feats), dim=2, keepdims=True)
564
+ # feats leads with B x S x 1
565
+ feats = torch.unbind(feats, dim=1)
566
+ return self.summ_oneds(name=name, ims=feats, norm=True, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
567
+
568
+ else:
569
+ __p = lambda x: utils.basic.pack_seqdim(x, B)
570
+ __u = lambda x: utils.basic.unpack_seqdim(x, B)
571
+
572
+ feats_ = __p(feats)
573
+
574
+ if valids is None:
575
+ feats_pca_ = get_feat_pca(feats_)
576
+ else:
577
+ valids_ = __p(valids)
578
+ feats_pca_ = get_feat_pca(feats_, valids_)
579
+
580
+ feats_pca = __u(feats_pca_)
581
+
582
+ return self.summ_rgbs(name=name, ims=torch.unbind(feats_pca, dim=1), only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
583
+
584
+ def summ_feat(self, name, feat, valid=None, pca=True, only_return=False, bev=False, fro=False, frame_id=None, frame_str=None):
585
+ if self.save_this:
586
+ if feat.ndim==5: # B x C x D x H x W
587
+
588
+ if bev:
589
+ reduce_axis = 3
590
+ elif fro:
591
+ reduce_axis = 2
592
+ else:
593
+ # default to bev
594
+ reduce_axis = 3
595
+
596
+ if valid is None:
597
+ feat = torch.mean(feat, dim=reduce_axis)
598
+ else:
599
+ valid = valid.repeat(1, feat.size()[1], 1, 1, 1)
600
+ feat = utils.basic.reduce_masked_mean(feat, valid, dim=reduce_axis)
601
+
602
+ B, C, D, W = list(feat.shape)
603
+
604
+ if not pca:
605
+ feat = torch.mean(torch.abs(feat), dim=1, keepdims=True)
606
+ # feat is B x 1 x D x W
607
+ return self.summ_oned(name=name, im=feat, norm=True, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
608
+ else:
609
+ feat_pca = get_feat_pca(feat, valid)
610
+ return self.summ_rgb(name, feat_pca, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
611
+
612
+ def summ_scalar(self, name, value):
613
+ if (not (isinstance(value, int) or isinstance(value, float) or isinstance(value, np.float32))) and ('torch' in value.type()):
614
+ value = value.detach().cpu().numpy()
615
+ if not np.isnan(value):
616
+ if (self.log_freq == 1):
617
+ self.writer.add_scalar(name, value, global_step=self.global_step)
618
+ elif self.save_this or self.save_scalar:
619
+ self.writer.add_scalar(name, value, global_step=self.global_step)
620
+
621
+ def summ_traj2ds_on_rgbs(self, name, trajs, rgbs, visibs=None, valids=None, frame_ids=None, frame_strs=None, only_return=False, show_dots=True, cmap='coolwarm', vals=None, linewidth=1, max_show=1024):
622
+ # trajs is B, S, N, 2
623
+ # rgbs is B, S, C, H, W
624
+ B, S, C, H, W = rgbs.shape
625
+ B, S2, N, D = trajs.shape
626
+ assert(S==S2)
627
+
628
+
629
+ rgbs = rgbs[0] # S, C, H, W
630
+ trajs = trajs[0] # S, N, 2
631
+ if valids is None:
632
+ valids = torch.ones_like(trajs[:,:,0]) # S, N
633
+ else:
634
+ valids = valids[0]
635
+
636
+ if visibs is None:
637
+ visibs = torch.ones_like(trajs[:,:,0]) # S, N
638
+ else:
639
+ visibs = visibs[0]
640
+
641
+ if vals is not None:
642
+ vals = vals[0] # N
643
+ # print('vals', vals.shape)
644
+
645
+ if N > max_show:
646
+ inds = np.random.choice(N, max_show)
647
+ trajs = trajs[:,inds]
648
+ valids = valids[:,inds]
649
+ visibs = visibs[:,inds]
650
+ if vals is not None:
651
+ vals = vals[inds]
652
+ N = trajs.shape[1]
653
+
654
+ trajs = trajs.clamp(-16, W+16)
655
+
656
+ rgbs_color = []
657
+ for rgb in rgbs:
658
+ rgb = back2color(rgb).detach().cpu().numpy()
659
+ rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
660
+ rgbs_color.append(rgb) # each element H x W x 3
661
+
662
+ for i in range(min(N, max_show)):
663
+ if cmap=='onediff' and i==0:
664
+ cmap_ = 'spring'
665
+ elif cmap=='onediff':
666
+ cmap_ = 'winter'
667
+ else:
668
+ cmap_ = cmap
669
+ traj = trajs[:,i].long().detach().cpu().numpy() # S, 2
670
+ valid = valids[:,i].long().detach().cpu().numpy() # S
671
+
672
+ # print('traj', traj.shape)
673
+ # print('valid', valid.shape)
674
+
675
+ if vals is not None:
676
+ # val = vals[:,i].float().detach().cpu().numpy() # []
677
+ val = vals[i].float().detach().cpu().numpy() # []
678
+ # print('val', val.shape)
679
+ else:
680
+ val = None
681
+
682
+ for t in range(S):
683
+ if valid[t]:
684
+ rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj[:t+1], S=S, show_dots=show_dots, cmap=cmap_, val=val, linewidth=linewidth)
685
+
686
+ for i in range(min(N, max_show)):
687
+ if cmap=='onediff' and i==0:
688
+ cmap_ = 'spring'
689
+ elif cmap=='onediff':
690
+ cmap_ = 'winter'
691
+ else:
692
+ cmap_ = cmap
693
+ traj = trajs[:,i] # S,2
694
+ vis = visibs[:,i].round() # S
695
+ valid = valids[:,i] # S
696
+ rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth)
697
+
698
+ rgbs = []
699
+ for rgb in rgbs_color:
700
+ rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
701
+ rgbs.append(preprocess_color(rgb))
702
+
703
+ return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
704
+
705
+ def summ_traj2ds_on_rgbs2(self, name, trajs, visibles, rgbs, valids=None, frame_ids=None, frame_strs=None, only_return=False, show_dots=True, cmap=None, linewidth=1, max_show=1024):
706
+ # trajs is B, S, N, 2
707
+ # rgbs is B, S, C, H, W
708
+ B, S, C, H, W = rgbs.shape
709
+ B, S2, N, D = trajs.shape
710
+ assert(S==S2)
711
+
712
+ rgbs = rgbs[0] # S, C, H, W
713
+ trajs = trajs[0] # S, N, 2
714
+ visibles = visibles[0] # S, N
715
+ if valids is None:
716
+ valids = torch.ones_like(trajs[:,:,0]) # S, N
717
+ else:
718
+ valids = valids[0]
719
+
720
+ rgbs_color = []
721
+ for rgb in rgbs:
722
+ rgb = back2color(rgb).detach().cpu().numpy()
723
+ rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
724
+ rgbs_color.append(rgb) # each element H x W x 3
725
+
726
+ trajs = trajs.long().detach().cpu().numpy() # S, N, 2
727
+ visibles = visibles.float().detach().cpu().numpy() # S, N
728
+ valids = valids.long().detach().cpu().numpy() # S, N
729
+
730
+ for i in range(min(N, max_show)):
731
+ if cmap=='onediff' and i==0:
732
+ cmap_ = 'spring'
733
+ elif cmap=='onediff':
734
+ cmap_ = 'winter'
735
+ else:
736
+ cmap_ = cmap
737
+ traj = trajs[:,i] # S,2
738
+ vis = visibles[:,i] # S
739
+ valid = valids[:,i] # S
740
+ rgbs_color = self.draw_traj_on_images_py(rgbs_color, traj, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth)
741
+
742
+ for i in range(min(N, max_show)):
743
+ if cmap=='onediff' and i==0:
744
+ cmap_ = 'spring'
745
+ elif cmap=='onediff':
746
+ cmap_ = 'winter'
747
+ else:
748
+ cmap_ = cmap
749
+ traj = trajs[:,i] # S,2
750
+ vis = visibles[:,i] # S
751
+ valid = valids[:,i] # S
752
+ rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=None, linewidth=linewidth)
753
+
754
+ rgbs = []
755
+ for rgb in rgbs_color:
756
+ rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
757
+ rgbs.append(preprocess_color(rgb))
758
+
759
+ return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
760
+
761
+ def summ_traj2ds_on_rgb(self, name, trajs, rgb, valids=None, show_dots=True, show_lines=True, frame_id=None, frame_str=None, only_return=False, cmap='coolwarm', linewidth=1, max_show=1024):
762
+ # trajs is B, S, N, 2
763
+ # rgb is B, C, H, W
764
+ B, C, H, W = rgb.shape
765
+ B, S, N, D = trajs.shape
766
+
767
+ rgb = rgb[0] # S, C, H, W
768
+ trajs = trajs[0] # S, N, 2
769
+
770
+ if valids is None:
771
+ valids = torch.ones_like(trajs[:,:,0])
772
+ else:
773
+ valids = valids[0]
774
+
775
+ rgb_color = back2color(rgb).detach().cpu().numpy()
776
+ rgb_color = np.transpose(rgb_color, [1, 2, 0]) # put channels last
777
+
778
+ # using maxdist will dampen the colors for short motions
779
+ # norms = torch.sqrt(1e-4 + torch.sum((trajs[-1] - trajs[0])**2, dim=1)) # N
780
+ # maxdist = torch.quantile(norms, 0.95).detach().cpu().numpy()
781
+ maxdist = None
782
+ trajs = trajs.long().detach().cpu().numpy() # S, N, 2
783
+ valids = valids.long().detach().cpu().numpy() # S, N
784
+
785
+ if N > max_show:
786
+ inds = np.random.choice(N, max_show)
787
+ trajs = trajs[:,inds]
788
+ valids = valids[:,inds]
789
+ N = trajs.shape[1]
790
+
791
+ for i in range(min(N, max_show)):
792
+ if cmap=='onediff' and i==0:
793
+ cmap_ = 'spring'
794
+ elif cmap=='onediff':
795
+ cmap_ = 'winter'
796
+ else:
797
+ cmap_ = cmap
798
+ traj = trajs[:,i] # S, 2
799
+ valid = valids[:,i] # S
800
+ if valid[0]==1:
801
+ traj = traj[valid>0]
802
+ rgb_color = self.draw_traj_on_image_py(
803
+ rgb_color, traj, S=S, show_dots=show_dots, show_lines=show_lines, cmap=cmap_, maxdist=maxdist, linewidth=linewidth)
804
+
805
+ rgb_color = torch.from_numpy(rgb_color).permute(2, 0, 1).unsqueeze(0)
806
+ rgb = preprocess_color(rgb_color)
807
+ return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
808
+
809
+ def draw_traj_on_image_py(self, rgb, traj, S=50, linewidth=1, show_dots=False, show_lines=True, cmap='coolwarm', val=None, maxdist=None):
810
+ # all inputs are numpy tensors
811
+ # rgb is H x W x 3
812
+ # traj is S x 2
813
+
814
+ H, W, C = rgb.shape
815
+ assert(C==3)
816
+
817
+ rgb = rgb.astype(np.uint8).copy()
818
+
819
+ S1, D = traj.shape
820
+ assert(D==2)
821
+
822
+ color_map = cm.get_cmap(cmap)
823
+ S1, D = traj.shape
824
+
825
+ for s in range(S1):
826
+ if val is not None:
827
+ color = np.array(color_map(val)[:3]) * 255 # rgb
828
+ else:
829
+ if maxdist is not None:
830
+ val = (np.sqrt(np.sum((traj[s]-traj[0])**2))/maxdist).clip(0,1)
831
+ color = np.array(color_map(val)[:3]) * 255 # rgb
832
+ else:
833
+ color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # rgb
834
+
835
+ if show_lines and s<(S1-1):
836
+ cv2.line(rgb,
837
+ (int(traj[s,0]), int(traj[s,1])),
838
+ (int(traj[s+1,0]), int(traj[s+1,1])),
839
+ color,
840
+ linewidth,
841
+ cv2.LINE_AA)
842
+ if show_dots:
843
+ cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, -1)
844
+
845
+ # if maxdist is not None:
846
+ # val = (np.sqrt(np.sum((traj[-1]-traj[0])**2))/maxdist).clip(0,1)
847
+ # color = np.array(color_map(val)[:3]) * 255 # rgb
848
+ # else:
849
+ # # draw the endpoint of traj, using the next color (which may be the last color)
850
+ # color = np.array(color_map((S1-1)/max(1,float(S-2)))[:3]) * 255 # rgb
851
+
852
+ # # emphasize endpoint
853
+ # cv2.circle(rgb, (traj[-1,0], traj[-1,1]), linewidth*2, color, -1)
854
+
855
+ return rgb
856
+
857
+
858
+ def draw_traj_on_images_py(self, rgbs, traj, S=50, linewidth=1, show_dots=False, cmap='coolwarm', maxdist=None):
859
+ # all inputs are numpy tensors
860
+ # rgbs is a list of H,W,3
861
+ # traj is S,2
862
+ H, W, C = rgbs[0].shape
863
+ assert(C==3)
864
+
865
+ rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs]
866
+
867
+ S1, D = traj.shape
868
+ assert(D==2)
869
+
870
+ x = int(np.clip(traj[0,0], 0, W-1))
871
+ y = int(np.clip(traj[0,1], 0, H-1))
872
+ color = rgbs[0][y,x]
873
+ color = (int(color[0]),int(color[1]),int(color[2]))
874
+ for s in range(S):
875
+ # bak_color = np.array(color_map(1.0)[:3]) * 255 # rgb
876
+ # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth*4, bak_color, -1)
877
+ cv2.polylines(rgbs[s],
878
+ [traj[:s+1]],
879
+ False,
880
+ color,
881
+ linewidth,
882
+ cv2.LINE_AA)
883
+ return rgbs
884
+
885
+ def draw_circs_on_image_py(self, rgb, xy, colors=None, linewidth=10, radius=3, show_dots=False, maxdist=None):
886
+ # all inputs are numpy tensors
887
+ # rgb is H x W x 3
888
+ # xy is N,2
889
+ H, W, C = rgb.shape
890
+ assert(C==3)
891
+
892
+ rgb = rgb.astype(np.uint8).copy()
893
+
894
+ N, D = xy.shape
895
+ assert(D==2)
896
+
897
+
898
+ xy = xy.astype(np.float32)
899
+ xy[:,0] = np.clip(xy[:,0], 0, W-1)
900
+ xy[:,1] = np.clip(xy[:,1], 0, H-1)
901
+ xy = xy.astype(np.int32)
902
+
903
+
904
+
905
+ if colors is None:
906
+ colors = get_n_colors(N)
907
+
908
+ for n in range(N):
909
+ color = colors[n]
910
+ # print('color', color)
911
+ # color = (color[0]*255).astype(np.uint8)
912
+ color = (int(color[0]),int(color[1]),int(color[2]))
913
+
914
+ # x = int(np.clip(xy[0,0], 0, W-1))
915
+ # y = int(np.clip(xy[0,1], 0, H-1))
916
+ # color_ = rgbs[0][y,x]
917
+ # color_ = (int(color_[0]),int(color_[1]),int(color_[2]))
918
+ # color_ = (int(color_[0]),int(color_[1]),int(color_[2]))
919
+
920
+ cv2.circle(rgb, (int(xy[n,0]), int(xy[n,1])), linewidth, color, 3)
921
+ # vis_color = int(np.squeeze(vis[s])*255)
922
+ # vis_color = (vis_color,vis_color,vis_color)
923
+ # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth+1, vis_color, -1)
924
+ return rgb
925
+
926
+ def draw_circ_on_images_py(self, rgbs, traj, vis, S=50, linewidth=1, show_dots=False, cmap=None, maxdist=None):
927
+ # all inputs are numpy tensors
928
+ # rgbs is a list of H x W x 3
929
+ # traj is S,2
930
+ H, W, C = rgbs[0].shape
931
+ assert(C==3)
932
+
933
+ rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs]
934
+
935
+ S1, D = traj.shape
936
+ assert(D==2)
937
+
938
+ if cmap is None:
939
+ bremm = ColorMap2d()
940
+ traj_ = traj[0:1].astype(np.float32)
941
+ traj_[:,0] /= float(W)
942
+ traj_[:,1] /= float(H)
943
+ color = bremm(traj_)
944
+ # print('color', color)
945
+ color = (color[0]*255).astype(np.uint8)
946
+ color = (int(color[0]),int(color[1]),int(color[2]))
947
+
948
+ for s in range(S):
949
+ if cmap is not None:
950
+ color_map = cm.get_cmap(cmap)
951
+ # color = np.array(color_map(s/(S-1))[:3]) * 255 # rgb
952
+ color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # rgb
953
+ # color = color.astype(np.uint8)
954
+ # color = (color[0], color[1], color[2])
955
+ # print('color', color)
956
+ # import ipdb; ipdb.set_trace()
957
+
958
+ cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+2, color, -1)
959
+ vis_color = int(np.squeeze(vis[s])*255)
960
+ vis_color = (vis_color,vis_color,vis_color)
961
+ cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, vis_color, -1)
962
+
963
+ return rgbs
964
+
965
+ def summ_pts_on_rgb(self, name, trajs, rgb, visibs=None, valids=None, frame_id=None, frame_str=None, only_return=False, show_dots=True, colors=None, cmap='coolwarm', linewidth=1, max_show=1024, already_sorted=False):
966
+ # trajs is B, S, N, 2
967
+ # rgb is B, C, H, W
968
+ B, C, H, W = rgb.shape
969
+ B, S, N, D = trajs.shape
970
+
971
+ rgb = rgb[0] # C, H, W
972
+ trajs = trajs[0] # S, N, 2
973
+ if valids is None:
974
+ valids = torch.ones_like(trajs[:,:,0]) # S, N
975
+ else:
976
+ valids = valids[0]
977
+ if visibs is None:
978
+ visibs = torch.ones_like(trajs[:,:,0]) # S, N
979
+ else:
980
+ visibs = visibs[0]
981
+
982
+ trajs = trajs.clamp(-16, W+16)
983
+
984
+ if N > max_show:
985
+ inds = np.random.choice(N, max_show)
986
+ trajs = trajs[:,inds]
987
+ valids = valids[:,inds]
988
+ visibs = visibs[:,inds]
989
+ N = trajs.shape[1]
990
+
991
+ if not already_sorted:
992
+ inds = torch.argsort(torch.mean(trajs[:,:,1], dim=0))
993
+ trajs = trajs[:,inds]
994
+ valids = valids[:,inds]
995
+ visibs = visibs[:,inds]
996
+
997
+ rgb = back2color(rgb).detach().cpu().numpy()
998
+ rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
999
+
1000
+ trajs = trajs.long().detach().cpu().numpy() # S, N, 2
1001
+ valids = valids.long().detach().cpu().numpy() # S, N
1002
+ visibs = visibs.long().detach().cpu().numpy() # S, N
1003
+
1004
+ rgb = rgb.astype(np.uint8).copy()
1005
+
1006
+ for i in range(min(N, max_show)):
1007
+ if cmap=='onediff' and i==0:
1008
+ cmap_ = 'spring'
1009
+ elif cmap=='onediff':
1010
+ cmap_ = 'winter'
1011
+ else:
1012
+ cmap_ = cmap
1013
+ traj = trajs[:,i] # S,2
1014
+ valid = valids[:,i] # S
1015
+ visib = visibs[:,i] # S
1016
+
1017
+ if colors is None:
1018
+ ii = i/(1e-4+N-1.0)
1019
+ color_map = cm.get_cmap(cmap)
1020
+ color = np.array(color_map(ii)[:3]) * 255 # rgb
1021
+ else:
1022
+ color = np.array(colors[i]).astype(np.int64)
1023
+ color = (int(color[0]),int(color[1]),int(color[2]))
1024
+
1025
+ for s in range(S):
1026
+ if valid[s]:
1027
+ if visib[s]:
1028
+ thickness = -1
1029
+ else:
1030
+ thickness = 2
1031
+ cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, thickness)
1032
+ rgb = torch.from_numpy(rgb).permute(2,0,1).unsqueeze(0)
1033
+ rgb = preprocess_color(rgb)
1034
+ return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
1035
+
1036
+ def summ_pts_on_rgbs(self, name, trajs, rgbs, visibs=None, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap='coolwarm', colors=None, linewidth=1, max_show=1024, frame_strs=None):
1037
+ # trajs is B, S, N, 2
1038
+ # rgbs is B, S, C, H, W
1039
+ B, S, C, H, W = rgbs.shape
1040
+ B, S2, N, D = trajs.shape
1041
+ assert(S==S2)
1042
+
1043
+ rgbs = rgbs[0] # S, C, H, W
1044
+ trajs = trajs[0] # S, N, 2
1045
+ if valids is None:
1046
+ valids = torch.ones_like(trajs[:,:,0]) # S, N
1047
+ else:
1048
+ valids = valids[0]
1049
+ if visibs is None:
1050
+ visibs = torch.ones_like(trajs[:,:,0]) # S, N
1051
+ else:
1052
+ visibs = visibs[0]
1053
+
1054
+ if N > max_show:
1055
+ inds = np.random.choice(N, max_show)
1056
+ trajs = trajs[:,inds]
1057
+ valids = valids[:,inds]
1058
+ visibs = visibs[:,inds]
1059
+ N = trajs.shape[1]
1060
+ inds = torch.argsort(torch.mean(trajs[:,:,1], dim=0))
1061
+ trajs = trajs[:,inds]
1062
+ valids = valids[:,inds]
1063
+ visibs = visibs[:,inds]
1064
+
1065
+ rgbs_color = []
1066
+ for rgb in rgbs:
1067
+ rgb = back2color(rgb).detach().cpu().numpy()
1068
+ rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
1069
+ rgbs_color.append(rgb) # each element H x W x 3
1070
+
1071
+ trajs = trajs.long().detach().cpu().numpy() # S, N, 2
1072
+ valids = valids.long().detach().cpu().numpy() # S, N
1073
+ visibs = visibs.long().detach().cpu().numpy() # S, N
1074
+
1075
+ rgbs_color = [rgb.astype(np.uint8).copy() for rgb in rgbs_color]
1076
+
1077
+ for i in range(min(N, max_show)):
1078
+ traj = trajs[:,i] # S,2
1079
+ valid = valids[:,i] # S
1080
+ visib = visibs[:,i] # S
1081
+
1082
+ if colors is None:
1083
+ ii = i/(1e-4+N-1.0)
1084
+ color_map = cm.get_cmap(cmap)
1085
+ color = np.array(color_map(ii)[:3]) * 255 # rgb
1086
+ else:
1087
+ color = np.array(colors[i]).astype(np.int64)
1088
+ color = (int(color[0]),int(color[1]),int(color[2]))
1089
+
1090
+ for s in range(S):
1091
+ if valid[s]:
1092
+ if visib[s]:
1093
+ thickness = -1
1094
+ else:
1095
+ thickness = 2
1096
+ cv2.circle(rgbs_color[s], (int(traj[s,0]), int(traj[s,1])), int(linewidth), color, thickness)
1097
+ rgbs = []
1098
+ for rgb in rgbs_color:
1099
+ rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
1100
+ rgbs.append(preprocess_color(rgb))
1101
+
1102
+ return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
1103
+
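The summ_* point visualizers above expect batched torch tensors; a minimal shape sketch (the writer instance `sw` is hypothetical here and stands for the summary-writer class defined earlier in this file):

import torch

B, S, N, H, W = 1, 8, 64, 256, 256
rgb = torch.zeros(B, 3, H, W)                                  # one frame, preprocess_color-style range
trajs = torch.rand(B, S, N, 2) * torch.tensor([W, H]).float()  # (x, y) pixel coords per frame
visibs = torch.ones(B, S, N)
# sw.summ_pts_on_rgb('tracks', trajs, rgb, visibs=visibs)      # hypothetical writer instance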
utils/loss.py ADDED
@@ -0,0 +1,220 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from typing import List
5
+ import utils.basic
6
+
7
+
8
+ def sequence_loss(
9
+ flow_preds,
10
+ flow_gt,
11
+ valids,
12
+ vis=None,
13
+ gamma=0.8,
14
+ use_huber_loss=False,
15
+ loss_only_for_visible=False,
16
+ ):
17
+ """Loss function defined over sequence of flow predictions"""
18
+ total_flow_loss = 0.0
19
+ for j in range(len(flow_gt)):
20
+ B, S, N, D = flow_gt[j].shape
21
+ B, S2, N = valids[j].shape
22
+ assert S == S2
23
+ n_predictions = len(flow_preds[j])
24
+ flow_loss = 0.0
25
+ for i in range(n_predictions):
26
+ i_weight = gamma ** (n_predictions - i - 1)
27
+ flow_pred = flow_preds[j][i]
28
+ if use_huber_loss:
29
+ i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0)
30
+ else:
31
+ i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
32
+ i_loss = torch.mean(i_loss, dim=3) # B, S, N
33
+ valid_ = valids[j].clone()
34
+ if loss_only_for_visible:
35
+ valid_ = valid_ * vis[j]
36
+ flow_loss += i_weight * utils.basic.reduce_masked_mean(i_loss, valid_)
37
+ flow_loss = flow_loss / n_predictions
38
+ total_flow_loss += flow_loss
39
+ return total_flow_loss / len(flow_gt)
40
+
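A minimal sketch of how sequence_loss is shaped to be called, assuming the repo's utils package is importable; the outer list indexes groups of ground truth, the inner list indexes refinement iterations:

import torch
from utils.loss import sequence_loss

B, S, N = 1, 8, 16
flow_gt = [torch.zeros(B, S, N, 2)]                               # one group of GT tracks
valids = [torch.ones(B, S, N)]
flow_preds = [[torch.randn(B, S, N, 2) * 0.1 for _ in range(4)]]  # 4 refinement iterations
loss = sequence_loss(flow_preds, flow_gt, valids, gamma=0.8)
print(loss)  # scalar; later iterations get exponentially larger weight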
41
+ def sequence_loss_dense(
42
+ flow_preds,
43
+ flow_gt,
44
+ valids,
45
+ vis=None,
46
+ gamma=0.8,
47
+ use_huber_loss=False,
48
+ loss_only_for_visible=False,
49
+ ):
50
+ """Loss function defined over sequence of flow predictions"""
51
+ total_flow_loss = 0.0
52
+ for j in range(len(flow_gt)):
53
+ # print('flow_gt[j]', flow_gt[j].shape)
54
+ B, S, D, H, W = flow_gt[j].shape
55
+ B, S2, _, H, W = valids[j].shape
56
+ assert S == S2
57
+ n_predictions = len(flow_preds[j])
58
+ flow_loss = 0.0
59
+ # import ipdb; ipdb.set_trace()
60
+ for i in range(n_predictions):
61
+ # print('flow_e[j][i]', flow_preds[j][i].shape)
62
+ i_weight = gamma ** (n_predictions - i - 1)
63
+ flow_pred = flow_preds[j][i] # B,S,2,H,W
64
+ if use_huber_loss:
65
+ i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0) # B,S,2,H,W
66
+ else:
67
+ i_loss = (flow_pred - flow_gt[j]).abs() # B,S,2,H,W
68
+ i_loss_ = torch.mean(i_loss, dim=2) # B,S,H,W
69
+ valid_ = valids[j].reshape(B,S,H,W)
70
+ # print(' (%d,%d) i_loss_' % (i,j), i_loss_.shape)
71
+ # print(' (%d,%d) valid_' % (i,j), valid_.shape)
72
+ if loss_only_for_visible:
73
+ valid_ = valid_ * vis[j].reshape(B,-1,H,W) # usually B,S,H,W, but maybe B,1,H,W
74
+ flow_loss += i_weight * utils.basic.reduce_masked_mean(i_loss_, valid_, broadcast=True)
75
+ # import ipdb; ipdb.set_trace()
76
+ flow_loss = flow_loss / n_predictions
77
+ total_flow_loss += flow_loss
78
+ return total_flow_loss / len(flow_gt)
79
+
80
+
81
+ def huber_loss(x, y, delta=1.0):
82
+ """Calculate element-wise Huber loss between x and y"""
83
+ diff = x - y
84
+ abs_diff = diff.abs()
85
+ flag = (abs_diff <= delta).float()
86
+ return flag * 0.5 * diff**2 + (1 - flag) * delta * (abs_diff - 0.5 * delta)
87
+
88
+
89
+ def sequence_BCE_loss(vis_preds, vis_gts, valids=None, use_logits=False):
90
+ total_bce_loss = 0.0
91
+ # all_vis_preds = [torch.stack(vp) for vp in vis_preds]
92
+ # all_vis_preds = torch.stack(all_vis_preds)
93
+ # utils.basic.print_stats('all_vis_preds', all_vis_preds)
94
+ for j in range(len(vis_preds)):
95
+ n_predictions = len(vis_preds[j])
96
+ bce_loss = 0.0
97
+ for i in range(n_predictions):
98
+ # utils.basic.print_stats('vis_preds[%d][%d]' % (j,i), vis_preds[j][i])
99
+ # utils.basic.print_stats('vis_gts[%d]' % (i), vis_gts[i])
100
+ if use_logits:
101
+ loss = F.binary_cross_entropy_with_logits(vis_preds[j][i], vis_gts[j], reduction='none')
102
+ else:
103
+ loss = F.binary_cross_entropy(vis_preds[j][i], vis_gts[j], reduction='none')
104
+ if valids is None:
105
+ bce_loss += loss.mean()
106
+ else:
107
+ bce_loss += (loss * valids[j]).mean()
108
+ bce_loss = bce_loss / n_predictions
109
+ total_bce_loss += bce_loss
110
+ return total_bce_loss / len(vis_preds)
111
+
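The visibility loss uses the same list-of-lists layout; a sketch with sigmoid-space predictions (use_logits=False):

import torch
from utils.loss import sequence_BCE_loss

vis_gts = [torch.randint(0, 2, (1, 8, 16)).float()]
vis_preds = [[torch.rand(1, 8, 16) for _ in range(4)]]  # probabilities in (0, 1)
print(sequence_BCE_loss(vis_preds, vis_gts))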
112
+
113
+ # def sequence_BCE_loss_dense(vis_preds, vis_gts):
114
+ # total_bce_loss = 0.0
115
+ # for j in range(len(vis_preds)):
116
+ # n_predictions = len(vis_preds[j])
117
+ # bce_loss = 0.0
118
+ # for i in range(n_predictions):
119
+ # vis_e = vis_preds[j][i]
120
+ # vis_g = vis_gts[j]
121
+ # print('vis_e', vis_e.shape, 'vis_g', vis_g.shape)
122
+ # vis_loss = F.binary_cross_entropy(vis_e, vis_g)
123
+ # bce_loss += vis_loss
124
+ # bce_loss = bce_loss / n_predictions
125
+ # total_bce_loss += bce_loss
126
+ # return total_bce_loss / len(vis_preds)
127
+
128
+
129
+ def sequence_prob_loss(
130
+ tracks: torch.Tensor,
131
+ confidence: torch.Tensor,
132
+ target_points: torch.Tensor,
133
+ visibility: torch.Tensor,
134
+ expected_dist_thresh: float = 12.0,
135
+ use_logits=False,
136
+ ):
137
+ """Loss for classifying if a point is within pixel threshold of its target."""
138
+ # Points with an error larger than 12 pixels are likely to be useless; marking
139
+ # them as occluded will actually improve Jaccard metrics and give
140
+ # qualitatively better results.
141
+ total_logprob_loss = 0.0
142
+ for j in range(len(tracks)):
143
+ n_predictions = len(tracks[j])
144
+ logprob_loss = 0.0
145
+ for i in range(n_predictions):
146
+ err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=-1)
147
+ valid = (err <= expected_dist_thresh**2).float()
148
+ if use_logits:
149
+ loss = F.binary_cross_entropy_with_logits(confidence[j][i], valid, reduction="none")
150
+ else:
151
+ loss = F.binary_cross_entropy(confidence[j][i], valid, reduction="none")
152
+ loss *= visibility[j]
153
+ loss = torch.mean(loss, dim=[1, 2])
154
+ logprob_loss += loss
155
+ logprob_loss = logprob_loss / n_predictions
156
+ total_logprob_loss += logprob_loss
157
+ return total_logprob_loss / len(tracks)
158
+
159
+ def sequence_prob_loss_dense(
160
+ tracks: torch.Tensor,
161
+ confidence: torch.Tensor,
162
+ target_points: torch.Tensor,
163
+ visibility: torch.Tensor,
164
+ expected_dist_thresh: float = 12.0,
165
+ use_logits=False,
166
+ ):
167
+ """Loss for classifying if a point is within pixel threshold of its target."""
168
+ # Points with an error larger than 12 pixels are likely to be useless; marking
169
+ # them as occluded will actually improve Jaccard metrics and give
170
+ # qualitatively better results.
171
+
172
+ # all_confidence = [torch.stack(vp) for vp in confidence]
173
+ # all_confidence = torch.stack(all_confidence)
174
+ # utils.basic.print_stats('all_confidence', all_confidence)
175
+
176
+ total_logprob_loss = 0.0
177
+ for j in range(len(tracks)):
178
+ n_predictions = len(tracks[j])
179
+ logprob_loss = 0.0
180
+ for i in range(n_predictions):
181
+ # print('trajs_e', tracks[j][i].shape)
182
+ # print('trajs_g', target_points[j].shape)
183
+ err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=2)
184
+ positive = (err <= expected_dist_thresh**2).float()
185
+ # print('conf', confidence[j][i].shape, 'positive', positive.shape)
186
+ if use_logits:
187
+ loss = F.binary_cross_entropy_with_logits(confidence[j][i].squeeze(2), positive, reduction="none")
188
+ else:
189
+ loss = F.binary_cross_entropy(confidence[j][i].squeeze(2), positive, reduction="none")
190
+ loss *= visibility[j].squeeze(2) # B,S,H,W
191
+ loss = torch.mean(loss, dim=[1,2,3])
192
+ logprob_loss += loss
193
+ logprob_loss = logprob_loss / n_predictions
194
+ total_logprob_loss += logprob_loss
195
+ return total_logprob_loss / len(tracks)
196
+
197
+
198
+ def masked_mean(data, mask, dim):
199
+ if mask is None:
200
+ return data.mean(dim=dim, keepdim=True)
201
+ mask = mask.float()
202
+ mask_sum = torch.sum(mask, dim=dim, keepdim=True)
203
+ mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
204
+ mask_sum, min=1.0
205
+ )
206
+ return mask_mean
207
+
208
+
209
+ def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
210
+ if mask is None:
211
+ return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
212
+ mask = mask.float()
213
+ mask_sum = torch.sum(mask, dim=dim, keepdim=True)
214
+ mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
215
+ mask_sum, min=1.0
216
+ )
217
+ mask_var = torch.sum(
218
+ mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
219
+ ) / torch.clamp(mask_sum, min=1.0)
220
+ return mask_mean.squeeze(dim), mask_var.squeeze(dim)
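A small sanity check for masked_mean, keeping the keepdim convention used above:

import torch
from utils.loss import masked_mean

data = torch.tensor([[1.0, 2.0, 100.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])
print(masked_mean(data, mask, dim=1))  # tensor([[1.5000]]); the masked 100.0 is ignored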
utils/misc.py ADDED
@@ -0,0 +1,100 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, positions):
5
+ assert embed_dim % 2 == 0
6
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
7
+ omega /= embed_dim / 2.0
8
+ omega = 1.0 / 10000**omega # (D/2,)
9
+
10
+ positions = positions.reshape(-1) # (M,)
11
+ out = torch.einsum("m,d->md", positions, omega) # (M, D/2), outer product
12
+
13
+ emb_sin = torch.sin(out) # (M, D/2)
14
+ emb_cos = torch.cos(out) # (M, D/2)
15
+
16
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
17
+ return emb[None].float()
18
+
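A shape check for the sinusoidal embedding (positions are passed as double here to match omega's dtype):

import torch
from utils.misc import get_1d_sincos_pos_embed_from_grid

pos = torch.arange(8, dtype=torch.double)
emb = get_1d_sincos_pos_embed_from_grid(16, pos)
print(emb.shape)  # torch.Size([1, 8, 16]); sin features then cos features along the last dim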
19
+
20
+ class SimplePool():
21
+ def __init__(self, pool_size, version='pt', min_size=1):
22
+ self.pool_size = pool_size
23
+ self.version = version
24
+ self.items = []
25
+ self.min_size = min_size
26
+
27
+ if not (version=='pt' or version=='np'):
28
+ print('version = %s; please choose pt or np' % version)
29
+ assert(False) # please choose pt or np
30
+
31
+ def __len__(self):
32
+ return len(self.items)
33
+
34
+ def mean(self, min_size=None):
35
+ if min_size is None:
36
+ pool_size_thresh = self.min_size
37
+ elif min_size=='half':
38
+ pool_size_thresh = self.pool_size/2
39
+ else:
40
+ pool_size_thresh = min_size
41
+
42
+ if self.version=='np':
43
+ if len(self.items) >= pool_size_thresh:
44
+ return np.sum(self.items)/float(len(self.items))
45
+ else:
46
+ return np.nan
47
+ if self.version=='pt':
48
+ if len(self.items) >= pool_size_thresh:
49
+ return torch.stack(self.items).sum()/float(len(self.items))
50
+ else:
51
+ return torch.tensor(float('nan'))
52
+
53
+ def sample(self, with_replacement=True):
54
+ idx = np.random.randint(len(self.items))
55
+ if with_replacement:
56
+ return self.items[idx]
57
+ else:
58
+ return self.items.pop(idx)
59
+
60
+ def fetch(self, num=None):
61
+ if self.version=='pt':
62
+ item_array = torch.stack(self.items)
63
+ elif self.version=='np':
64
+ item_array = np.stack(self.items)
65
+ if num is not None:
66
+ # there better be some items
67
+ assert(len(self.items) >= num)
68
+
69
+ # if there are not that many elements just return however many there are
70
+ if len(self.items) < num:
71
+ return item_array
72
+ else:
73
+ idxs = np.random.randint(len(self.items), size=num)
74
+ return item_array[idxs]
75
+ else:
76
+ return item_array
77
+
78
+ def is_full(self):
79
+ full = len(self.items)==self.pool_size
80
+ return full
81
+
82
+ def empty(self):
83
+ self.items = []
84
+
85
+ def have_min_size(self):
86
+ return len(self.items) >= self.min_size
87
+
88
+
89
+ def update(self, items):
90
+ for item in items:
91
+ if len(self.items) < self.pool_size:
92
+ # the pool is not full, so let's add this in
93
+ self.items.append(item)
94
+ else:
95
+ # the pool is full
96
+ # pop from the front
97
+ self.items.pop(0)
98
+ # add to the back
99
+ self.items.append(item)
100
+ return self.items
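A sketch of SimplePool as a fixed-size running-average buffer (numpy mode):

from utils.misc import SimplePool

pool = SimplePool(pool_size=3, version='np')
pool.update([0.9, 0.7, 0.5, 0.3])   # the oldest value (0.9) is evicted once the pool is full
print(len(pool), pool.mean())       # 3 0.5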
utils/py.py ADDED
@@ -0,0 +1,755 @@
1
+ import glob, math
2
+ import numpy as np
3
+ # from scipy import misc
4
+ # from scipy import linalg
5
+ from PIL import Image
6
+ import io
7
+ import matplotlib.pyplot as plt
8
+ EPS = 1e-6
9
+
10
+
11
+ XMIN = -64.0 # right (neg is left)
12
+ XMAX = 64.0 # right
13
+ YMIN = -64.0 # down (neg is up)
14
+ YMAX = 64.0 # down
15
+ ZMIN = -64.0 # forward
16
+ ZMAX = 64.0 # forward
17
+
18
+ def print_stats(name, tensor):
19
+ tensor = tensor.astype(np.float32)
20
+ print('%s min = %.2f, mean = %.2f, max = %.2f' % (name, np.min(tensor), np.mean(tensor), np.max(tensor)), tensor.shape)
21
+
22
+ def reduce_masked_mean(x, mask, axis=None, keepdims=False):
23
+ # x and mask are the same shape
24
+ # returns shape-1
25
+ # axis can be a list of axes
26
+ prod = x*mask
27
+ numer = np.sum(prod, axis=axis, keepdims=keepdims)
28
+ denom = EPS+np.sum(mask, axis=axis, keepdims=keepdims)
29
+ mean = numer/denom
30
+ return mean
31
+
32
+ def reduce_masked_sum(x, mask, axis=None, keepdims=False):
33
+ # x and mask are the same shape
34
+ # returns shape-1
35
+ # axis can be a list of axes
36
+ prod = x*mask
37
+ numer = np.sum(prod, axis=axis, keepdims=keepdims)
38
+ return numer
39
+
40
+ def reduce_masked_median(x, mask, keep_batch=False):
41
+ # x and mask are the same shape
42
+ # returns shape-1
43
+ # axis can be a list of axes
44
+
45
+ if not (x.shape == mask.shape):
46
+ print('reduce_masked_median: these shapes should match:', x.shape, mask.shape)
47
+ assert(False)
48
+ # assert(x.shape == mask.shape)
49
+
50
+ B = list(x.shape)[0]
51
+
52
+ if keep_batch:
53
+ x = np.reshape(x, [B, -1])
54
+ mask = np.reshape(mask, [B, -1])
55
+ meds = np.zeros([B], np.float32)
56
+ for b in list(range(B)):
57
+ xb = x[b]
58
+ mb = mask[b]
59
+ if np.sum(mb) > 0:
60
+ xb = xb[mb > 0]
61
+ meds[b] = np.median(xb)
62
+ else:
63
+ meds[b] = np.nan
64
+ return meds
65
+ else:
66
+ x = np.reshape(x, [-1])
67
+ mask = np.reshape(mask, [-1])
68
+ if np.sum(mask) > 0:
69
+ x = x[mask > 0]
70
+ med = np.median(x)
71
+ else:
72
+ med = np.nan
73
+ med = np.array([med], np.float32)
74
+ return med
75
+
76
+ def get_nFiles(path):
77
+ return len(glob.glob(path))
78
+
79
+ def get_file_list(path):
80
+ return glob.glob(path)
81
+
82
+ def rotm2eul(R):
83
+ # R is 3x3
84
+ sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0])
85
+ if sy > 1e-6: # singular
86
+ x = math.atan2(R[2,1] , R[2,2])
87
+ y = math.atan2(-R[2,0], sy)
88
+ z = math.atan2(R[1,0], R[0,0])
89
+ else:
90
+ x = math.atan2(-R[1,2], R[1,1])
91
+ y = math.atan2(-R[2,0], sy)
92
+ z = 0
93
+ return x, y, z
94
+
95
+ def rad2deg(rad):
96
+ return rad*180.0/np.pi
97
+
98
+ def deg2rad(deg):
99
+ return deg/180.0*np.pi
100
+
101
+ def eul2rotm(rx, ry, rz):
102
+ # copy of matlab, but order of inputs is different
103
+ # R = [ cy*cz sy*sx*cz-sz*cx sy*cx*cz+sz*sx
104
+ # cy*sz sy*sx*sz+cz*cx sy*cx*sz-cz*sx
105
+ # -sy cy*sx cy*cx]
106
+ sinz = np.sin(rz)
107
+ siny = np.sin(ry)
108
+ sinx = np.sin(rx)
109
+ cosz = np.cos(rz)
110
+ cosy = np.cos(ry)
111
+ cosx = np.cos(rx)
112
+ r11 = cosy*cosz
113
+ r12 = sinx*siny*cosz - cosx*sinz
114
+ r13 = cosx*siny*cosz + sinx*sinz
115
+ r21 = cosy*sinz
116
+ r22 = sinx*siny*sinz + cosx*cosz
117
+ r23 = cosx*siny*sinz - sinx*cosz
118
+ r31 = -siny
119
+ r32 = sinx*cosy
120
+ r33 = cosx*cosy
121
+ r1 = np.stack([r11,r12,r13],axis=-1)
122
+ r2 = np.stack([r21,r22,r23],axis=-1)
123
+ r3 = np.stack([r31,r32,r33],axis=-1)
124
+ r = np.stack([r1,r2,r3],axis=0)
125
+ return r
126
+
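A round-trip sanity check for the Euler-angle helpers (all angles in radians):

from utils.py import eul2rotm, rotm2eul

R = eul2rotm(0.1, 0.2, 0.3)
print(R.shape)      # (3, 3)
print(rotm2eul(R))  # approximately (0.1, 0.2, 0.3)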
127
+ def wrap2pi(rad_angle):
128
+ # puts the angle into the range [-pi, pi]
129
+ return np.arctan2(np.sin(rad_angle), np.cos(rad_angle))
130
+
131
+ def rot2view(rx,ry,rz,x,y,z):
132
+ # takes rot angles and 3d position as input
133
+ # returns viewpoint angles as output
134
+ # (all in radians)
135
+ # it will perform strangely if z <= 0
136
+ az = wrap2pi(ry - (-np.arctan2(z, x) - 1.5*np.pi))
137
+ el = -wrap2pi(rx - (-np.arctan2(z, y) - 1.5*np.pi))
138
+ th = -rz
139
+ return az, el, th
140
+
141
+ def invAxB(a,b):
142
+ """
143
+ Compute the relative 3d transformation between a and b.
144
+
145
+ Input:
146
+ a -- first pose (homogeneous 4x4 matrix)
147
+ b -- second pose (homogeneous 4x4 matrix)
148
+
149
+ Output:
150
+ Relative 3d transformation from a to b.
151
+ """
152
+ return np.dot(np.linalg.inv(a),b)
153
+
154
+ def merge_rt(r, t):
155
+ # r is 3 x 3
156
+ # t is 3 or maybe 3 x 1
157
+ t = np.reshape(t, [3, 1])
158
+ rt = np.concatenate((r,t), axis=1)
159
+ # rt is 3 x 4
160
+ br = np.reshape(np.array([0,0,0,1], np.float32), [1, 4])
161
+ # br is 1 x 4
162
+ rt = np.concatenate((rt, br), axis=0)
163
+ # rt is 4 x 4
164
+ return rt
165
+
166
+ def split_rt(rt):
167
+ r = rt[:3,:3]
168
+ t = rt[:3,3]
169
+ r = np.reshape(r, [3, 3])
170
+ t = np.reshape(t, [3, 1])
171
+ return r, t
172
+
173
+ def split_intrinsics(K):
174
+ # K is 3 x 4 or 4 x 4
175
+ fx = K[0,0]
176
+ fy = K[1,1]
177
+ x0 = K[0,2]
178
+ y0 = K[1,2]
179
+ return fx, fy, x0, y0
180
+
181
+ def merge_intrinsics(fx, fy, x0, y0):
182
+ # inputs are shaped []
183
+ K = np.eye(4)
184
+ K[0,0] = fx
185
+ K[1,1] = fy
186
+ K[0,2] = x0
187
+ K[1,2] = y0
188
+ # K is shaped 4 x 4
189
+ return K
190
+
191
+ def scale_intrinsics(K, sx, sy):
192
+ fx, fy, x0, y0 = split_intrinsics(K)
193
+ fx *= sx
194
+ fy *= sy
195
+ x0 *= sx
196
+ y0 *= sy
197
+ return merge_intrinsics(fx, fy, x0, y0)
198
+
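A sketch of the intrinsics helpers, e.g. rescaling a camera matrix for a half-resolution image:

from utils.py import merge_intrinsics, scale_intrinsics, split_intrinsics

K = merge_intrinsics(500.0, 500.0, 320.0, 240.0)   # fx, fy, x0, y0
K_half = scale_intrinsics(K, 0.5, 0.5)
print(split_intrinsics(K_half))  # (250.0, 250.0, 160.0, 120.0)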
199
+ # def meshgrid(H, W):
200
+ # x = np.linspace(0, W-1, W)
201
+ # y = np.linspace(0, H-1, H)
202
+ # xv, yv = np.meshgrid(x, y)
203
+ # return xv, yv
204
+
205
+ def compute_distance(transform):
206
+ """
207
+ Compute the distance of the translational component of a 4x4 homogeneous matrix.
208
+ """
209
+ return np.linalg.norm(transform[0:3,3])
210
+
211
+ def radian_l1_dist(e, g):
212
+ # if our angles are in [0, 360] we can follow this stack overflow answer:
213
+ # https://gamedev.stackexchange.com/questions/4467/comparing-angles-and-working-out-the-difference
214
+ # wrap2pi brings the angles to [-pi, pi]; adding pi puts them in [0, 2*pi]
215
+ e = wrap2pi(e)+np.pi
216
+ g = wrap2pi(g)+np.pi
217
+ l = np.abs(np.pi - np.abs(np.abs(e-g) - np.pi))
218
+ return l
219
+
220
+ def apply_pix_T_cam(pix_T_cam, xyz):
221
+ fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
222
+ # xyz is shaped N x 3
223
+ # returns xy, shaped N x 2
224
+ N, C = xyz.shape
225
+ x, y, z = np.split(xyz, 3, axis=-1)
226
+ EPS = 1e-4
227
+ z = np.clip(z, EPS, None)
228
+ x = (x*fx)/(z)+x0
229
+ y = (y*fy)/(z)+y0
230
+ xy = np.concatenate([x, y], axis=-1)
231
+ return xy
232
+
233
+ def apply_4x4(RT, XYZ):
234
+ # RT is 4 x 4
235
+ # XYZ is N x 3
236
+
237
+ # put into homogeneous coords
238
+ X, Y, Z = np.split(XYZ, 3, axis=1)
239
+ ones = np.ones_like(X)
240
+ XYZ1 = np.concatenate([X, Y, Z, ones], axis=1)
241
+ # XYZ1 is N x 4
242
+
243
+ XYZ1_t = np.transpose(XYZ1)
244
+ # this is 4 x N
245
+
246
+ XYZ2_t = np.dot(RT, XYZ1_t)
247
+ # this is 4 x N
248
+
249
+ XYZ2 = np.transpose(XYZ2_t)
250
+ # this is N x 4
251
+
252
+ XYZ2 = XYZ2[:,:3]
253
+ # this is N x 3
254
+
255
+ return XYZ2
256
+
257
+ def Ref2Mem(xyz, Z, Y, X):
258
+ # xyz is N x 3, in ref coordinates
259
+ # transforms ref coordinates into mem coordinates
260
+ N, C = xyz.shape
261
+ assert(C==3)
262
+ mem_T_ref = get_mem_T_ref(Z, Y, X)
263
+ xyz = apply_4x4(mem_T_ref, xyz)
264
+ return xyz
265
+
266
+ # def Mem2Ref(xyz_mem, MH, MW, MD):
267
+ # # xyz is B x N x 3, in mem coordinates
268
+ # # transforms mem coordinates into ref coordinates
269
+ # B, N, C = xyz_mem.get_shape().as_list()
270
+ # ref_T_mem = get_ref_T_mem(B, MH, MW, MD)
271
+ # xyz_ref = utils_geom.apply_4x4(ref_T_mem, xyz_mem)
272
+ # return xyz_ref
273
+
274
+ def get_mem_T_ref(Z, Y, X):
275
+ # sometimes we want the mat itself
276
+ # note this is not a rigid transform
277
+
278
+ # for interpretability, let's construct this in two steps...
279
+
280
+ # translation
281
+ center_T_ref = np.eye(4, dtype=np.float32)
282
+ center_T_ref[0,3] = -XMIN
283
+ center_T_ref[1,3] = -YMIN
284
+ center_T_ref[2,3] = -ZMIN
285
+
286
+ VOX_SIZE_X = (XMAX-XMIN)/float(X)
287
+ VOX_SIZE_Y = (YMAX-YMIN)/float(Y)
288
+ VOX_SIZE_Z = (ZMAX-ZMIN)/float(Z)
289
+
290
+ # scaling
291
+ mem_T_center = np.eye(4, dtype=np.float32)
292
+ mem_T_center[0,0] = 1./VOX_SIZE_X
293
+ mem_T_center[1,1] = 1./VOX_SIZE_Y
294
+ mem_T_center[2,2] = 1./VOX_SIZE_Z
295
+
296
+ mem_T_ref = np.dot(mem_T_center, center_T_ref)
297
+ return mem_T_ref
298
+
299
+ def safe_inverse(a):
300
+ r, t = split_rt(a)
301
+ t = np.reshape(t, [3, 1])
302
+ r_transpose = r.T
303
+ inv = np.concatenate([r_transpose, -np.matmul(r_transpose, t)], 1)
304
+ bottom_row = a[3:4, :] # this is [0, 0, 0, 1]
305
+ inv = np.concatenate([inv, bottom_row], 0)
306
+ return inv
307
+
308
+ def get_ref_T_mem(Z, Y, X):
309
+ mem_T_ref = get_mem_T_ref(Z, Y, X)
310
+ # note safe_inverse is inapplicable here,
311
+ # since the transform is nonrigid
312
+ ref_T_mem = np.linalg.inv(mem_T_ref)
313
+ return ref_T_mem
314
+
315
+ def voxelize_xyz(xyz_ref, Z, Y, X):
316
+ # xyz_ref is N x 3
317
+ xyz_mem = Ref2Mem(xyz_ref, Z, Y, X)
318
+ # this is N x 3
319
+ voxels = get_occupancy(xyz_mem, Z, Y, X)
320
+ voxels = np.reshape(voxels, [Z, Y, X, 1])
321
+ return voxels
322
+
323
+ def get_inbounds(xyz, Z, Y, X, already_mem=False):
324
+ # xyz is H*W x 3
325
+
326
+ if not already_mem:
327
+ xyz = Ref2Mem(xyz, Z, Y, X)
328
+
329
+ x_valid = np.logical_and(
330
+ np.greater_equal(xyz[:,0], -0.5),
331
+ np.less(xyz[:,0], float(X)-0.5))
332
+ y_valid = np.logical_and(
333
+ np.greater_equal(xyz[:,1], -0.5),
334
+ np.less(xyz[:,1], float(Y)-0.5))
335
+ z_valid = np.logical_and(
336
+ np.greater_equal(xyz[:,2], -0.5),
337
+ np.less(xyz[:,2], float(Z)-0.5))
338
+ inbounds = np.logical_and(np.logical_and(x_valid, y_valid), z_valid)
339
+ return inbounds
340
+
341
+ def sub2ind3d_zyx(depth, height, width, d, h, w):
342
+ # same as sub2ind3d, but inputs in zyx order
343
+ # when gathering/scattering with these inds, the tensor should be Z x Y x X
344
+ return d*height*width + h*width + w
345
+
346
+ def sub2ind3d_yxz(height, width, depth, h, w, d):
347
+ return h*width*depth + w*depth + d
348
+
349
+ # def ind2sub(height, width, ind):
350
+ # # int input
351
+ # y = int(ind / height)
352
+ # x = ind % height
353
+ # return y, x
354
+
355
+ def get_occupancy(xyz_mem, Z, Y, X):
356
+ # xyz_mem is N x 3
357
+ # we want to fill a voxel tensor with 1's at these inds
358
+
359
+ inbounds = get_inbounds(xyz_mem, Z, Y, X, already_mem=True)
360
+ inds = np.where(inbounds)
361
+
362
+ xyz_mem = np.reshape(xyz_mem[inds], [-1, 3])
363
+ # xyz_mem is N x 3
364
+
365
+ # this is more accurate than a cast/floor, but runs into issues when Y==0
366
+ xyz_mem = np.round(xyz_mem).astype(np.int32)
367
+ x = xyz_mem[:,0]
368
+ y = xyz_mem[:,1]
369
+ z = xyz_mem[:,2]
370
+
371
+ voxels = np.zeros([Z, Y, X], np.float32)
372
+ voxels[z, y, x] = 1.0
373
+
374
+ return voxels
375
+
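A sketch of the voxelization path; the metric extent comes from the XMIN/XMAX-style constants at the top of this file:

import numpy as np
from utils.py import voxelize_xyz

xyz = np.array([[0.0, 0.0, 0.0], [10.0, -5.0, 20.0]], np.float32)  # reference-frame points
vox = voxelize_xyz(xyz, Z=32, Y=32, X=32)
print(vox.shape, vox.sum())  # (32, 32, 32, 1) 2.0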
376
+ def pixels2camera(x,y,z,fx,fy,x0,y0):
377
+ # x and y are locations in pixel coordinates, z is a depth image in meters
378
+ # their shapes are H x W
379
+ # fx, fy, x0, y0 are scalar camera intrinsics
380
+ # returns xyz, sized [H*W, 3]
381
+
382
+ H, W = z.shape
383
+
384
+ fx = np.reshape(fx, [1,1])
385
+ fy = np.reshape(fy, [1,1])
386
+ x0 = np.reshape(x0, [1,1])
387
+ y0 = np.reshape(y0, [1,1])
388
+
389
+ # unproject
390
+ x = ((z+EPS)/fx)*(x-x0)
391
+ y = ((z+EPS)/fy)*(y-y0)
392
+
393
+ x = np.reshape(x, [-1])
394
+ y = np.reshape(y, [-1])
395
+ z = np.reshape(z, [-1])
396
+ xyz = np.stack([x,y,z], axis=1)
397
+ return xyz
398
+
399
+ def depth2pointcloud(z, pix_T_cam):
400
+ H = z.shape[0]
401
+ W = z.shape[1]
402
+ y, x = meshgrid2d(H, W)
403
+ z = np.reshape(z, [H, W])
404
+
405
+ fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
406
+ xyz = pixels2camera(x, y, z, fx, fy, x0, y0)
407
+ return xyz
408
+
409
+ def meshgrid2d(Y, X):
410
+ grid_y = np.linspace(0.0, Y-1, Y)
411
+ grid_y = np.reshape(grid_y, [Y, 1])
412
+ grid_y = np.tile(grid_y, [1, X])
413
+
414
+ grid_x = np.linspace(0.0, X-1, X)
415
+ grid_x = np.reshape(grid_x, [1, X])
416
+ grid_x = np.tile(grid_x, [Y, 1])
417
+
418
+ # outputs are Y x X
419
+ return grid_y, grid_x
420
+
421
+ def gridcloud3d(Y, X, Z):
422
+ x_ = np.linspace(0, X-1, X)
423
+ y_ = np.linspace(0, Y-1, Y)
424
+ z_ = np.linspace(0, Z-1, Z)
425
+ y, x, z = np.meshgrid(y_, x_, z_, indexing='ij')
426
+ x = np.reshape(x, [-1])
427
+ y = np.reshape(y, [-1])
428
+ z = np.reshape(z, [-1])
429
+ xyz = np.stack([x,y,z], axis=1).astype(np.float32)
430
+ return xyz
431
+
432
+ def gridcloud2d(Y, X):
433
+ x_ = np.linspace(0, X-1, X)
434
+ y_ = np.linspace(0, Y-1, Y)
435
+ y, x = np.meshgrid(y_, x_, indexing='ij')
436
+ x = np.reshape(x, [-1])
437
+ y = np.reshape(y, [-1])
438
+ xy = np.stack([x,y], axis=1).astype(np.float32)
439
+ return xy
440
+
441
+ def normalize(im):
442
+ im = im - np.min(im)
443
+ im = im / np.max(im)
444
+ return im
445
+
446
+ def wrap2pi(rad_angle):
447
+ # rad_angle can be any shape
448
+ # puts the angle into the range [-pi, pi]
449
+ return np.arctan2(np.sin(rad_angle), np.cos(rad_angle))
450
+
451
+ def convert_occ_to_height(occ):
452
+ Z, Y, X, C = occ.shape
453
+ assert(C==1)
454
+
455
+ height = np.linspace(float(Y), 1.0, Y)
456
+ height = np.reshape(height, [1, Y, 1, 1])
457
+ height = np.max(occ*height, axis=1)/float(Y)
458
+ height = np.reshape(height, [Z, X, C])
459
+ return height
460
+
461
+ def create_depth_image(xy, Z, H, W):
462
+
463
+ # turn the xy coordinates into image inds
464
+ xy = np.round(xy)
465
+
466
+ # lidar reports a sphere of measurements
467
+ # only use the inds that are within the image bounds
468
+ # also, only use forward-pointing depths (Z > 0)
469
+ valid = (xy[:,0] < W-1) & (xy[:,1] < H-1) & (xy[:,0] >= 0) & (xy[:,1] >= 0) & (Z[:] > 0)
470
+
471
+ # gather these up
472
+ xy = xy[valid]
473
+ Z = Z[valid]
474
+
475
+ inds = (xy[:,1].astype(np.int32) * W + xy[:,0].astype(np.int32)) # flatten (y,x) into raster-order indices
476
+ depth = np.zeros((H*W), np.float32)
477
+
478
+ for (index, replacement) in zip(inds, Z):
479
+ depth[index] = replacement
480
+ depth[np.where(depth == 0.0)] = 70.0
481
+ depth = np.reshape(depth, [H, W])
482
+
483
+ return depth
484
+
485
+ def vis_depth(depth, maxdepth=80.0, log_vis=True):
486
+ depth[depth<=0.0] = maxdepth
487
+ if log_vis:
488
+ depth = np.log(depth)
489
+ depth = np.clip(depth, 0, np.log(maxdepth))
490
+ else:
491
+ depth = np.clip(depth, 0, maxdepth)
492
+ depth = (depth*255.0).astype(np.uint8)
493
+ return depth
494
+
495
+ def preprocess_color(x):
496
+ return x.astype(np.float32) * 1./255 - 0.5
497
+
498
+ def convert_box_to_ref_T_obj(boxes):
499
+ shape = boxes.shape
500
+ boxes = boxes.reshape(-1,9)
501
+ rots = [eul2rotm(rx,ry,rz)
502
+ for rx,ry,rz in boxes[:,6:]]
503
+ rots = np.stack(rots,axis=0)
504
+ trans = boxes[:,:3]
505
+ ref_T_objs = [merge_rt(rot,tran)
506
+ for rot,tran in zip(rots,trans)]
507
+ ref_T_objs = np.stack(ref_T_objs,axis=0)
508
+ ref_T_objs = ref_T_objs.reshape(shape[:-1]+(4,4))
509
+ ref_T_objs = ref_T_objs.astype(np.float32)
510
+ return ref_T_objs
511
+
512
+ def get_rot_from_delta(delta, yaw_only=False):
513
+ dx = delta[:,0]
514
+ dy = delta[:,1]
515
+ dz = delta[:,2]
516
+
517
+ bot_hyp = np.sqrt(dz**2 + dx**2)
518
+ # top_hyp = np.sqrt(bot_hyp**2 + dy**2)
519
+
520
+ pitch = -np.arctan2(dy, bot_hyp)
521
+ yaw = np.arctan2(dz, dx)
522
+
523
+ if yaw_only:
524
+ rot = [eul2rotm(0,y,0) for y in yaw]
525
+ else:
526
+ rot = [eul2rotm(0,y,p) for (p,y) in zip(pitch,yaw)]
527
+
528
+ rot = np.stack(rot)
529
+ # rot is B x 3 x 3
530
+ return rot
531
+
532
+ def im2col(im, psize):
533
+ n_channels = 1 if len(im.shape) == 2 else im.shape[0]
534
+ (n_channels, rows, cols) = (1,) * (3 - len(im.shape)) + im.shape
535
+
536
+ im_pad = np.zeros((n_channels,
537
+ int(math.ceil(1.0 * rows / psize) * psize),
538
+ int(math.ceil(1.0 * cols / psize) * psize)))
539
+ im_pad[:, 0:rows, 0:cols] = im
540
+
541
+ final = np.zeros((im_pad.shape[1], im_pad.shape[2], n_channels,
542
+ psize, psize))
543
+ for c in np.arange(n_channels):
544
+ for x in np.arange(psize):
545
+ for y in np.arange(psize):
546
+ im_shift = np.vstack(
547
+ (im_pad[c, x:], im_pad[c, :x]))
548
+ im_shift = np.column_stack(
549
+ (im_shift[:, y:], im_shift[:, :y]))
550
+ final[x::psize, y::psize, c] = np.swapaxes(
551
+ im_shift.reshape(int(im_pad.shape[1] / psize), psize,
552
+ int(im_pad.shape[2] / psize), psize), 1, 2)
553
+
554
+ return np.squeeze(final[0:rows - psize + 1, 0:cols - psize + 1])
555
+
556
+ def filter_discontinuities(depth, filter_size=9, thresh=10):
557
+ H, W = list(depth.shape)
558
+
559
+ # Ensure that filter sizes are okay
560
+ assert filter_size % 2 == 1, "Can only use odd filter sizes."
561
+
562
+ # Compute discontinuities
563
+ offset = int((filter_size - 1) / 2)
564
+ patches = 1.0 * im2col(depth, filter_size)
565
+ mids = patches[:, :, offset, offset]
566
+ mins = np.min(patches, axis=(2, 3))
567
+ maxes = np.max(patches, axis=(2, 3))
568
+
569
+ discont = np.maximum(np.abs(mins - mids),
570
+ np.abs(maxes - mids))
571
+ mark = discont > thresh
572
+
573
+ # Account for offsets
574
+ final_mark = np.zeros((H, W), dtype=np.uint16)
575
+ final_mark[offset:offset + mark.shape[0],
576
+ offset:offset + mark.shape[1]] = mark
577
+
578
+ return depth * (1 - final_mark)
579
+
580
+ def argmax2d(tensor):
581
+ Y, X = list(tensor.shape)
582
+ # flatten the Tensor along the height and width axes
583
+ flat_tensor = tensor.reshape(-1)
584
+ # argmax of the flat tensor
585
+ argmax = np.argmax(flat_tensor)
586
+
587
+ # convert the indices into 2d coordinates
588
+ argmax_y = argmax // X # row
589
+ argmax_x = argmax % X # col
590
+
591
+ return argmax_y, argmax_x
592
+
593
+ def plot_traj_3d(traj):
594
+ # traj is S x 3
595
+
596
+ # print('traj', traj.shape)
597
+ S, C = list(traj.shape)
598
+ assert(C==3)
599
+
600
+ fig = plt.figure()
601
+ ax = fig.add_subplot(111, projection='3d')
602
+
603
+ colors = [plt.cm.RdYlBu(i) for i in np.linspace(0,1,S)]
604
+ # print('colors', colors)
605
+
606
+ xs = traj[:,0]
607
+ ys = -traj[:,1]
608
+ zs = traj[:,2]
609
+
610
+ ax.scatter(xs, zs, ys, s=30, c=colors, marker='o', alpha=1.0, edgecolors=(0,0,0))#, color=color_map[n])
611
+
612
+ ax.set_xlabel('X')
613
+ ax.set_ylabel('Z')
614
+ ax.set_zlabel('Y')
615
+
616
+ ax.set_xlim(0,1)
617
+ ax.set_ylim(0,1) # this is really Z
618
+ ax.set_zlim(-1,0) # this is really Y
619
+
620
+ buf = io.BytesIO()
621
+ plt.savefig(buf, format='png')
622
+ buf.seek(0)
623
+ image = np.array(Image.open(buf)) # H x W x 4
624
+ image = image[:,:,:3]
625
+
626
+ plt.close()
627
+ return image
628
+
629
+ def camera2pixels(xyz, pix_T_cam):
630
+ # xyz is shaped N x 3
631
+ # returns xy, shaped N x 2
632
+
633
+ fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
634
+ x, y, z = xyz[:,0], xyz[:,1], xyz[:,2]
635
+
636
+ EPS = 1e-4
637
+ z = np.clip(z, EPS, None)
638
+ x = (x*fx)/z + x0
639
+ y = (y*fy)/z + y0
640
+ xy = np.stack([x, y], axis=-1)
641
+ return xy
642
+
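A sketch tying the camera helpers together: unproject a constant-depth image and reproject it:

import numpy as np
from utils.py import merge_intrinsics, depth2pointcloud, camera2pixels

K = merge_intrinsics(300.0, 300.0, 32.0, 24.0)
depth = np.ones((48, 64), np.float32) * 2.0
xyz = depth2pointcloud(depth, K)   # (48*64, 3) camera-frame points
xy = camera2pixels(xyz, K)         # (48*64, 2), approximately the original pixel grid
print(xyz.shape, xy.shape)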
643
+ def make_colorwheel():
644
+ """
645
+ Generates a color wheel for optical flow visualization as presented in:
646
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
647
+ URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
648
+
649
+ Code follows the original C++ source code of Daniel Scharstein.
650
+ Code follows the the Matlab source code of Deqing Sun.
651
+
652
+ Returns:
653
+ np.ndarray: Color wheel
654
+ """
655
+
656
+ RY = 15
657
+ YG = 6
658
+ GC = 4
659
+ CB = 11
660
+ BM = 13
661
+ MR = 6
662
+
663
+ ncols = RY + YG + GC + CB + BM + MR
664
+ colorwheel = np.zeros((ncols, 3))
665
+ col = 0
666
+
667
+ # RY
668
+ colorwheel[0:RY, 0] = 255
669
+ colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
670
+ col = col+RY
671
+ # YG
672
+ colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
673
+ colorwheel[col:col+YG, 1] = 255
674
+ col = col+YG
675
+ # GC
676
+ colorwheel[col:col+GC, 1] = 255
677
+ colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
678
+ col = col+GC
679
+ # CB
680
+ colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
681
+ colorwheel[col:col+CB, 2] = 255
682
+ col = col+CB
683
+ # BM
684
+ colorwheel[col:col+BM, 2] = 255
685
+ colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
686
+ col = col+BM
687
+ # MR
688
+ colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
689
+ colorwheel[col:col+MR, 0] = 255
690
+ return colorwheel
691
+
692
+ def flow_uv_to_colors(u, v, convert_to_bgr=False):
693
+ """
694
+ Applies the flow color wheel to (possibly clipped) flow components u and v.
695
+
696
+ According to the C++ source code of Daniel Scharstein
697
+ According to the Matlab source code of Deqing Sun
698
+
699
+ Args:
700
+ u (np.ndarray): Input horizontal flow of shape [H,W]
701
+ v (np.ndarray): Input vertical flow of shape [H,W]
702
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
703
+
704
+ Returns:
705
+ np.ndarray: Flow visualization image of shape [H,W,3]
706
+ """
707
+ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
708
+ colorwheel = make_colorwheel() # shape [55x3]
709
+ ncols = colorwheel.shape[0]
710
+ rad = np.sqrt(np.square(u) + np.square(v))
711
+ a = np.arctan2(-v, -u)/np.pi
712
+ fk = (a+1) / 2*(ncols-1)
713
+ k0 = np.floor(fk).astype(np.int32)
714
+ k1 = k0 + 1
715
+ k1[k1 == ncols] = 0
716
+ f = fk - k0
717
+ for i in range(colorwheel.shape[1]):
718
+ tmp = colorwheel[:,i]
719
+ col0 = tmp[k0] / 255.0
720
+ col1 = tmp[k1] / 255.0
721
+ col = (1-f)*col0 + f*col1
722
+ idx = (rad <= 1)
723
+ col[idx] = 1 - rad[idx] * (1-col[idx])
724
+ col[~idx] = col[~idx] * 0.75 # out of range
725
+ # Note the 2-i => BGR instead of RGB
726
+ ch_idx = 2-i if convert_to_bgr else i
727
+ flow_image[:,:,ch_idx] = np.floor(255 * col)
728
+ return flow_image
729
+
730
+ def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
731
+ """
732
+ Expects a two dimensional flow image of shape.
733
+
734
+ Args:
735
+ flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
736
+ clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
737
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
738
+
739
+ Returns:
740
+ np.ndarray: Flow visualization image of shape [H,W,3]
741
+ """
742
+ assert flow_uv.ndim == 3, 'input flow must have three dimensions'
743
+ assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
744
+ if clip_flow is not None:
745
+ flow_uv = np.clip(flow_uv, -clip_flow, clip_flow) / clip_flow
746
+ # flow_uv = np.clamp(flow, -clip, clip)/clip
747
+
748
+ u = flow_uv[:,:,0]
749
+ v = flow_uv[:,:,1]
750
+ rad = np.sqrt(np.square(u) + np.square(v))
751
+ rad_max = np.max(rad)
752
+ epsilon = 1e-5
753
+ u = u / (rad_max + epsilon)
754
+ v = v / (rad_max + epsilon)
755
+ return flow_uv_to_colors(u, v, convert_to_bgr)
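A sketch of the flow colorization on a synthetic field:

import numpy as np
from utils.py import flow_to_image

flow = np.zeros((64, 64, 2), np.float32)
flow[..., 0] = np.linspace(-5, 5, 64)[None, :]  # a horizontal-motion ramp
vis = flow_to_image(flow)
print(vis.shape, vis.dtype)  # (64, 64, 3) uint8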
utils/samp.py ADDED
@@ -0,0 +1,213 @@
1
+ import torch
2
+ import utils.basic
3
+ import torch.nn.functional as F
4
+
5
+ def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
6
+ r"""Sample a tensor using bilinear interpolation
7
+
8
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
9
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
10
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
11
+ convention.
12
+
13
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
14
+ :math:`B` is the batch size, :math:`C` is the number of channels,
15
+ :math:`H` is the height of the image, and :math:`W` is the width of the
16
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
17
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
18
+
19
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
20
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
21
+ that in this case the order of the components is slightly different
22
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
23
+
24
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
25
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
26
+ left-most image pixel :math:`W-1` to the center of the right-most
27
+ pixel.
28
+
29
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
30
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
31
+ the left-most pixel :math:`W` to the right edge of the right-most
32
+ pixel.
33
+
34
+ Similar conventions apply to the :math:`y` for the range
35
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
36
+ :math:`[0,T-1]` and :math:`[0,T]`.
37
+
38
+ Args:
39
+ input (Tensor): batch of input images.
40
+ coords (Tensor): batch of coordinates.
41
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
42
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
43
+
44
+ Returns:
45
+ Tensor: sampled points.
46
+ """
47
+
48
+ sizes = input.shape[2:]
49
+
50
+ assert len(sizes) in [2, 3]
51
+
52
+ if len(sizes) == 3:
53
+ # t x y -> x y t to match dimensions T H W in grid_sample
54
+ coords = coords[..., [1, 2, 0]]
55
+
56
+ if align_corners:
57
+ coords = coords * torch.tensor(
58
+ [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
59
+ )
60
+ else:
61
+ coords = coords * torch.tensor(
62
+ [2 / size for size in reversed(sizes)], device=coords.device
63
+ )
64
+
65
+ coords -= 1
66
+
67
+ return F.grid_sample(
68
+ input, coords, align_corners=align_corners, padding_mode=padding_mode
69
+ )
70
+
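A tiny numeric check of the pixel-coordinate convention described above (align_corners=True, so x is in [0, W-1]):

import torch
from utils.samp import bilinear_sampler

im = torch.arange(16.0).reshape(1, 1, 4, 4)   # B, C, H, W
coords = torch.tensor([[[[1.5, 2.0]]]])       # B, H_o, W_o, 2 as (x, y)
print(bilinear_sampler(im, coords))           # tensor([[[[9.5000]]]])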
71
+
72
+ def sample_features4d(input, coords):
73
+ r"""Sample spatial features
74
+
75
+ `sample_features4d(input, coords)` samples the spatial features
76
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
77
+
78
+ The field is sampled at coordinates :attr:`coords` using bilinear
79
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
80
+ 3)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
81
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
82
+
83
+ The output tensor has one feature per point, and has shape :math:`(B,
84
+ R, C)`.
85
+
86
+ Args:
87
+ input (Tensor): spatial features.
88
+ coords (Tensor): points.
89
+
90
+ Returns:
91
+ Tensor: sampled features.
92
+ """
93
+
94
+ B, _, _, _ = input.shape
95
+
96
+ # B R 2 -> B R 1 2
97
+ coords = coords.unsqueeze(2)
98
+
99
+ # B C R 1
100
+ feats = bilinear_sampler(input, coords)
101
+
102
+ return feats.permute(0, 2, 1, 3).view(
103
+ B, -1, feats.shape[1] * feats.shape[3]
104
+ ) # B C R 1 -> B R C
105
+
106
+
107
+ def sample_features5d(input, coords):
108
+ r"""Sample spatio-temporal features
109
+
110
+ `sample_features5d(input, coords)` works in the same way as
111
+ :func:`sample_features4d` but for spatio-temporal features and points:
112
+ :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
113
+ a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
114
+ x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
115
+
116
+ Args:
117
+ input (Tensor): spatio-temporal features.
118
+ coords (Tensor): spatio-temporal points.
119
+
120
+ Returns:
121
+ Tensor: sampled features.
122
+ """
123
+
124
+ B, T, _, _, _ = input.shape
125
+
126
+ # B T C H W -> B C T H W
127
+ input = input.permute(0, 2, 1, 3, 4)
128
+
129
+ # B R1 R2 3 -> B R1 R2 1 3
130
+ coords = coords.unsqueeze(3)
131
+
132
+ # B C R1 R2 1
133
+ feats = bilinear_sampler(input, coords)
134
+
135
+ return feats.permute(0, 2, 3, 1, 4).view(
136
+ B, feats.shape[2], feats.shape[3], feats.shape[1]
137
+ ) # B C R1 R2 1 -> B R1 R2 C
138
+
139
+
140
+ def bilinear_sample2d(im, x, y, return_inbounds=False):
141
+ # x and y are each B, N
142
+ # output is B, C, N
143
+ B, C, H, W = list(im.shape)
144
+ N = list(x.shape)[1]
145
+
146
+ x = x.float()
147
+ y = y.float()
148
+ H_f = torch.tensor(H, dtype=torch.float32)
149
+ W_f = torch.tensor(W, dtype=torch.float32)
150
+
151
+ # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
152
+
153
+ max_y = (H_f - 1).int()
154
+ max_x = (W_f - 1).int()
155
+
156
+ x0 = torch.floor(x).int()
157
+ x1 = x0 + 1
158
+ y0 = torch.floor(y).int()
159
+ y1 = y0 + 1
160
+
161
+ x0_clip = torch.clamp(x0, 0, max_x)
162
+ x1_clip = torch.clamp(x1, 0, max_x)
163
+ y0_clip = torch.clamp(y0, 0, max_y)
164
+ y1_clip = torch.clamp(y1, 0, max_y)
165
+ dim2 = W
166
+ dim1 = W * H
167
+
168
+ base = torch.arange(0, B, dtype=torch.int64, device=x.device)*dim1
169
+ base = torch.reshape(base, [B, 1]).repeat([1, N])
170
+
171
+ base_y0 = base + y0_clip * dim2
172
+ base_y1 = base + y1_clip * dim2
173
+
174
+ idx_y0_x0 = base_y0 + x0_clip
175
+ idx_y0_x1 = base_y0 + x1_clip
176
+ idx_y1_x0 = base_y1 + x0_clip
177
+ idx_y1_x1 = base_y1 + x1_clip
178
+
179
+ # use the indices to lookup pixels in the flat image
180
+ # im is B x C x H x W
181
+ # move C out to last dim
182
+ im_flat = (im.permute(0, 2, 3, 1)).reshape(B*H*W, C)
183
+ i_y0_x0 = im_flat[idx_y0_x0.long()]
184
+ i_y0_x1 = im_flat[idx_y0_x1.long()]
185
+ i_y1_x0 = im_flat[idx_y1_x0.long()]
186
+ i_y1_x1 = im_flat[idx_y1_x1.long()]
187
+
188
+ # Finally calculate interpolated values.
189
+ x0_f = x0.float()
190
+ x1_f = x1.float()
191
+ y0_f = y0.float()
192
+ y1_f = y1.float()
193
+
194
+ w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
195
+ w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
196
+ w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
197
+ w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
198
+
199
+ output = w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + \
200
+ w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
201
+ # output is B*N x C
202
+ output = output.view(B, -1, C)
203
+ output = output.permute(0, 2, 1)
204
+ # output is B x C x N
205
+
206
+ if return_inbounds:
207
+ x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
208
+ y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
209
+ inbounds = (x_valid & y_valid).float()
210
+ inbounds = inbounds.reshape(B, N) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
211
+ return output, inbounds
212
+
213
+ return output # B, C, N
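A shape sketch for sample_features4d, which wraps the sampler above for per-point feature gathering:

import torch
from utils.samp import sample_features4d

fmap = torch.randn(2, 32, 48, 64)                                   # B, C, H, W
pts = torch.tensor([[[10.0, 20.0], [30.0, 5.0]]]).repeat(2, 1, 1)   # B, R, 2 in (x, y)
print(sample_features4d(fmap, pts).shape)                           # torch.Size([2, 2, 32])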
utils/saveload.py ADDED
@@ -0,0 +1,65 @@
1
+ import pathlib
2
+ import os
3
+ import torch
4
+
5
+ def save(ckpt_dir, module, optimizer, scheduler, global_step, keep_latest=2, model_name='model'):
6
+ pathlib.Path(ckpt_dir).mkdir(exist_ok=True, parents=True)
7
+ prev_ckpts = list(pathlib.Path(ckpt_dir).glob('%s-*pth' % model_name))
8
+ prev_ckpts.sort(key=lambda p: p.stat().st_mtime,reverse=True)
9
+ if len(prev_ckpts) > keep_latest-1:
10
+ for f in prev_ckpts[keep_latest-1:]:
11
+ f.unlink()
12
+ save_path = '%s/%s-%09d.pth' % (ckpt_dir, model_name, global_step)
13
+ save_dict = {
14
+ "model": module.state_dict(),
15
+ "optimizer": optimizer.state_dict(),
16
+ "global_step": global_step,
17
+ }
18
+ if scheduler is not None:
19
+ save_dict['scheduler'] = scheduler.state_dict()
20
+ print(f"saving {save_path}")
21
+ torch.save(save_dict, save_path)
22
+ return False
23
+
24
+ def load(fabric, ckpt_path, model, optimizer=None, scheduler=None, model_ema=None, step=0, model_name='model', ignore_load=None, strict=True, verbose=True, weights_only=False):
25
+ if verbose:
26
+ print('reading ckpt from %s' % ckpt_path)
27
+ if not os.path.exists(ckpt_path):
28
+ print('...there is no full checkpoint in %s' % ckpt_path)
29
+ print('-- note this function no longer appends "saved_checkpoints/" before the ckpt_path --')
30
+ assert(False)
31
+ else:
32
+ if os.path.isfile(ckpt_path):
33
+ path = ckpt_path
34
+ print('...found checkpoint %s' % (path))
35
+ else:
36
+ prev_ckpts = list(pathlib.Path(ckpt_path).glob('%s-*pth' % model_name))
37
+ prev_ckpts.sort(key=lambda p: p.stat().st_mtime,reverse=True)
38
+ if len(prev_ckpts):
39
+ path = prev_ckpts[0]
40
+ # e.g., './checkpoints/2Ai4_5e-4_base18_1539/model-000050000.pth'
41
+ # OR ./whatever.pth
42
+ step = int(str(path).split('-')[-1].split('.')[0])
43
+ if verbose:
44
+ print('...found checkpoint %s; (parsed step %d from path)' % (path, step))
45
+ else:
46
+ print('...there is no full checkpoint here!')
47
+ return 0
48
+ if fabric is not None:
49
+ checkpoint = fabric.load(path)
50
+ else:
51
+ checkpoint = torch.load(path, weights_only=weights_only)
52
+ if optimizer is not None:
53
+ optimizer.load_state_dict(checkpoint['optimizer'])
54
+ if scheduler is not None:
55
+ scheduler.load_state_dict(checkpoint['scheduler'])
56
+ assert ignore_load is None # not ready yet
57
+ if 'model' in checkpoint:
58
+ state_dict = checkpoint['model']
59
+ else:
60
+ state_dict = checkpoint
61
+ model.load_state_dict(state_dict, strict=strict)
62
+ return step
63
+
64
+
65
+
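A checkpointing round-trip sketch; the directory name is illustrative:

import torch
from utils.saveload import save, load

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
save('./checkpoints_tmp', model, optimizer, scheduler=None, global_step=100)
step = load(None, './checkpoints_tmp', model, optimizer=optimizer)
print(step)  # 100, parsed from model-000000100.pth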