Spaces:
Running
on
Zero
Running
on
Zero
added basics
Browse files- nets/alltracker.py +588 -0
- nets/blocks.py +1304 -0
- utils/basic.py +144 -0
- utils/data.py +96 -0
- utils/improc.py +1103 -0
- utils/loss.py +220 -0
- utils/misc.py +100 -0
- utils/py.py +755 -0
- utils/samp.py +213 -0
- utils/saveload.py +65 -0
nets/alltracker.py
ADDED
@@ -0,0 +1,588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import utils.misc
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from nets.blocks import CNBlockConfig, ConvNeXt, conv1x1, RelUpdateBlock, InputPadder, CorrBlock, BasicEncoder
|
8 |
+
|
9 |
+
class Net(nn.Module):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
seqlen,
|
13 |
+
use_attn=True,
|
14 |
+
use_mixer=False,
|
15 |
+
use_conv=False,
|
16 |
+
use_convb=False,
|
17 |
+
use_basicencoder=False,
|
18 |
+
use_sinmotion=False,
|
19 |
+
use_relmotion=False,
|
20 |
+
use_sinrelmotion=False,
|
21 |
+
use_feats8=False,
|
22 |
+
no_time=False,
|
23 |
+
no_space=False,
|
24 |
+
no_split=False,
|
25 |
+
no_ctx=False,
|
26 |
+
full_split=False,
|
27 |
+
corr_levels=5,
|
28 |
+
corr_radius=4,
|
29 |
+
num_blocks=3,
|
30 |
+
dim=128,
|
31 |
+
hdim=128,
|
32 |
+
init_weights=True,
|
33 |
+
):
|
34 |
+
super(Net, self).__init__()
|
35 |
+
|
36 |
+
self.dim = dim
|
37 |
+
self.hdim = hdim
|
38 |
+
|
39 |
+
self.no_time = no_time
|
40 |
+
self.no_space = no_space
|
41 |
+
self.seqlen = seqlen
|
42 |
+
self.corr_levels = corr_levels
|
43 |
+
self.corr_radius = corr_radius
|
44 |
+
self.corr_channel = self.corr_levels * (self.corr_radius * 2 + 1) ** 2
|
45 |
+
self.num_blocks = num_blocks
|
46 |
+
|
47 |
+
self.use_feats8 = use_feats8
|
48 |
+
self.use_basicencoder = use_basicencoder
|
49 |
+
self.use_sinmotion = use_sinmotion
|
50 |
+
self.use_relmotion = use_relmotion
|
51 |
+
self.use_sinrelmotion = use_sinrelmotion
|
52 |
+
self.no_split = no_split
|
53 |
+
self.no_ctx = no_ctx
|
54 |
+
self.full_split = full_split
|
55 |
+
|
56 |
+
if use_basicencoder:
|
57 |
+
if self.full_split:
|
58 |
+
self.fnet = BasicEncoder(input_dim=3, output_dim=self.dim, stride=8)
|
59 |
+
self.cnet = BasicEncoder(input_dim=3, output_dim=self.dim, stride=8)
|
60 |
+
else:
|
61 |
+
if self.no_split:
|
62 |
+
self.fnet = BasicEncoder(input_dim=3, output_dim=self.dim, stride=8)
|
63 |
+
else:
|
64 |
+
self.fnet = BasicEncoder(input_dim=3, output_dim=self.dim*2, stride=8)
|
65 |
+
else:
|
66 |
+
block_setting = [
|
67 |
+
CNBlockConfig(96, 192, 3, True), # 4x
|
68 |
+
CNBlockConfig(192, 384, 3, False), # 8x
|
69 |
+
CNBlockConfig(384, None, 9, False), # 8x
|
70 |
+
]
|
71 |
+
self.cnn = ConvNeXt(block_setting, stochastic_depth_prob=0.0, init_weights=init_weights)
|
72 |
+
if self.no_split:
|
73 |
+
self.dot_conv = conv1x1(384, dim)
|
74 |
+
else:
|
75 |
+
self.dot_conv = conv1x1(384, dim*2)
|
76 |
+
|
77 |
+
self.upsample_weight = nn.Sequential(
|
78 |
+
# convex combination of 3x3 patches
|
79 |
+
nn.Conv2d(dim, dim * 2, 3, padding=1),
|
80 |
+
nn.ReLU(inplace=True),
|
81 |
+
nn.Conv2d(dim * 2, 64 * 9, 1, padding=0)
|
82 |
+
)
|
83 |
+
self.flow_head = nn.Sequential(
|
84 |
+
nn.Conv2d(dim, 2*dim, kernel_size=3, padding=1),
|
85 |
+
nn.ReLU(inplace=True),
|
86 |
+
nn.Conv2d(2*dim, 2, kernel_size=3, padding=1)
|
87 |
+
)
|
88 |
+
self.visconf_head = nn.Sequential(
|
89 |
+
nn.Conv2d(dim, 2*dim, kernel_size=3, padding=1),
|
90 |
+
nn.ReLU(inplace=True),
|
91 |
+
nn.Conv2d(2*dim, 2, kernel_size=3, padding=1)
|
92 |
+
)
|
93 |
+
|
94 |
+
if self.use_sinrelmotion:
|
95 |
+
self.pdim = 84 # 32*2
|
96 |
+
elif self.use_relmotion:
|
97 |
+
self.pdim = 4
|
98 |
+
elif self.use_sinmotion:
|
99 |
+
self.pdim = 42
|
100 |
+
else:
|
101 |
+
self.pdim = 2
|
102 |
+
|
103 |
+
self.update_block = RelUpdateBlock(self.corr_channel, self.num_blocks, cdim=dim, hdim=hdim, pdim=self.pdim,
|
104 |
+
use_attn=use_attn, use_mixer=use_mixer, use_conv=use_conv, use_convb=use_convb,
|
105 |
+
use_layer_scale=True, no_time=no_time, no_space=no_space,
|
106 |
+
no_ctx=no_ctx)
|
107 |
+
|
108 |
+
time_line = torch.linspace(0, seqlen-1, seqlen).reshape(1, seqlen, 1)
|
109 |
+
self.register_buffer("time_emb", utils.misc.get_1d_sincos_pos_embed_from_grid(self.dim, time_line[0])) # 1,S,C
|
110 |
+
|
111 |
+
|
112 |
+
def fetch_time_embed(self, t, dtype, is_training=False):
|
113 |
+
S = self.time_emb.shape[1]
|
114 |
+
if t == S:
|
115 |
+
return self.time_emb.to(dtype)
|
116 |
+
elif t==1:
|
117 |
+
if is_training:
|
118 |
+
ind = np.random.choice(S)
|
119 |
+
return self.time_emb[:,ind:ind+1].to(dtype)
|
120 |
+
else:
|
121 |
+
return self.time_emb[:,1:2].to(dtype)
|
122 |
+
else:
|
123 |
+
time_emb = self.time_emb.float()
|
124 |
+
time_emb = F.interpolate(time_emb.permute(0, 2, 1), size=t, mode="linear").permute(0, 2, 1)
|
125 |
+
return time_emb.to(dtype)
|
126 |
+
|
127 |
+
def coords_grid(self, batch, ht, wd, device, dtype):
|
128 |
+
coords = torch.meshgrid(torch.arange(ht, device=device, dtype=dtype), torch.arange(wd, device=device, dtype=dtype), indexing='ij')
|
129 |
+
coords = torch.stack(coords[::-1], dim=0)
|
130 |
+
return coords[None].repeat(batch, 1, 1, 1)
|
131 |
+
|
132 |
+
def initialize_flow(self, img):
|
133 |
+
""" Flow is represented as difference between two coordinate grids flow = coords2 - coords1"""
|
134 |
+
N, C, H, W = img.shape
|
135 |
+
coords1 = self.coords_grid(N, H//8, W//8, device=img.device)
|
136 |
+
coords2 = self.coords_grid(N, H//8, W//8, device=img.device)
|
137 |
+
return coords1, coords2
|
138 |
+
|
139 |
+
def upsample_data(self, flow, mask):
|
140 |
+
""" Upsample [H/8, W/8, C] -> [H, W, C] using convex combination """
|
141 |
+
N, C, H, W = flow.shape
|
142 |
+
mask = mask.view(N, 1, 9, 8, 8, H, W)
|
143 |
+
mask = torch.softmax(mask, dim=2)
|
144 |
+
|
145 |
+
up_flow = F.unfold(8 * flow, [3,3], padding=1)
|
146 |
+
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
|
147 |
+
|
148 |
+
up_flow = torch.sum(mask * up_flow, dim=2)
|
149 |
+
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
|
150 |
+
|
151 |
+
return up_flow.reshape(N, 2, 8*H, 8*W).to(flow.dtype)
|
152 |
+
|
153 |
+
def get_T_padded_images(self, images, T, S, is_training, stride=None, pad=True):
|
154 |
+
B,T,C,H,W = images.shape
|
155 |
+
indices = None
|
156 |
+
if T > 2:
|
157 |
+
step = S // 2 if stride is None else stride
|
158 |
+
indices = []
|
159 |
+
start = 0
|
160 |
+
while start + S < T:
|
161 |
+
indices.append(start)
|
162 |
+
start += step
|
163 |
+
indices.append(start)
|
164 |
+
Tpad = indices[-1]+S-T
|
165 |
+
if pad:
|
166 |
+
if is_training:
|
167 |
+
assert Tpad == 0
|
168 |
+
else:
|
169 |
+
images = images.reshape(B,1,T,C*H*W)
|
170 |
+
if Tpad > 0:
|
171 |
+
padding_tensor = images[:,:,-1:,:].expand(B,1,Tpad,C*H*W)
|
172 |
+
images = torch.cat([images, padding_tensor], dim=2)
|
173 |
+
images = images.reshape(B,T+Tpad,C,H,W)
|
174 |
+
T = T+Tpad
|
175 |
+
else:
|
176 |
+
assert T == 2
|
177 |
+
return images, T, indices
|
178 |
+
|
179 |
+
def get_fmaps(self, images_, B, T, sw, is_training):
|
180 |
+
_, _, H_pad, W_pad = images_.shape # revised HW
|
181 |
+
|
182 |
+
C, H8, W8 = self.dim*2, H_pad//8, W_pad//8
|
183 |
+
if self.no_split:
|
184 |
+
C = self.dim
|
185 |
+
|
186 |
+
fmaps_chunk_size = 32
|
187 |
+
if (not is_training) and (T > fmaps_chunk_size):
|
188 |
+
images = images_.reshape(B,T,3,H_pad,W_pad)
|
189 |
+
fmaps = []
|
190 |
+
for t in range(0, T, fmaps_chunk_size):
|
191 |
+
images_chunk = images[:, t : t + fmaps_chunk_size]
|
192 |
+
images_chunk = images_chunk.cuda()
|
193 |
+
if self.use_basicencoder:
|
194 |
+
if self.full_split:
|
195 |
+
fmaps_chunk1 = self.fnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
|
196 |
+
fmaps_chunk2 = self.cnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
|
197 |
+
fmaps_chunk = torch.cat([fmaps_chunk1, fmaps_chunk2], axis=1)
|
198 |
+
else:
|
199 |
+
fmaps_chunk = self.fnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
|
200 |
+
else:
|
201 |
+
fmaps_chunk = self.cnn(images_chunk.reshape(-1, 3, H_pad, W_pad))
|
202 |
+
if t==0 and sw is not None and sw.save_this:
|
203 |
+
sw.summ_feat('1_model/fmap_raw', fmaps_chunk[0:1])
|
204 |
+
fmaps_chunk = self.dot_conv(fmaps_chunk) # B*T,C,H8,W8
|
205 |
+
T_chunk = images_chunk.shape[1]
|
206 |
+
fmaps.append(fmaps_chunk.reshape(B, -1, C, H8, W8))
|
207 |
+
fmaps_ = torch.cat(fmaps, dim=1).reshape(-1, C, H8, W8)
|
208 |
+
else:
|
209 |
+
if not is_training:
|
210 |
+
# sometimes we need to move things to cuda here
|
211 |
+
images_ = images_.cuda()
|
212 |
+
if self.use_basicencoder:
|
213 |
+
if self.full_split:
|
214 |
+
fmaps1_ = self.fnet(images_)
|
215 |
+
fmaps2_ = self.cnet(images_)
|
216 |
+
fmaps_ = torch.cat([fmaps1_, fmaps2_], axis=1)
|
217 |
+
else:
|
218 |
+
fmaps_ = self.fnet(images_)
|
219 |
+
else:
|
220 |
+
fmaps_ = self.cnn(images_)
|
221 |
+
if sw is not None and sw.save_this:
|
222 |
+
sw.summ_feat('1_model/fmap_raw', fmaps_[0:1])
|
223 |
+
fmaps_ = self.dot_conv(fmaps_) # B*T,C,H8,W8
|
224 |
+
return fmaps_
|
225 |
+
|
226 |
+
def forward(self, images, iters=4, sw=None, is_training=False, stride=None):
|
227 |
+
B,T,C,H,W = images.shape
|
228 |
+
S = self.seqlen
|
229 |
+
device = images.device
|
230 |
+
dtype = images.dtype
|
231 |
+
|
232 |
+
print('images', images.shape)
|
233 |
+
|
234 |
+
# images are in [0,255]
|
235 |
+
mean = torch.as_tensor([0.485, 0.456, 0.406], device=device).reshape(1,1,3,1,1).to(images.dtype)
|
236 |
+
std = torch.as_tensor([0.229, 0.224, 0.225], device=device).reshape(1,1,3,1,1).to(images.dtype)
|
237 |
+
images = images / 255.0
|
238 |
+
images = (images - mean)/std
|
239 |
+
print("a0 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
240 |
+
|
241 |
+
T_bak = T
|
242 |
+
if stride is not None:
|
243 |
+
pad = False
|
244 |
+
else:
|
245 |
+
pad = True
|
246 |
+
images, T, indices = self.get_T_padded_images(images, T, S, is_training, stride=stride, pad=pad)
|
247 |
+
|
248 |
+
images = images.contiguous()
|
249 |
+
images_ = images.reshape(B*T,3,H,W)
|
250 |
+
padder = InputPadder(images_.shape)
|
251 |
+
images_ = padder.pad(images_)[0]
|
252 |
+
|
253 |
+
print("a1 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
254 |
+
|
255 |
+
_, _, H_pad, W_pad = images_.shape # revised HW
|
256 |
+
C, H8, W8 = self.dim*2, H_pad//8, W_pad//8
|
257 |
+
C2 = C//2
|
258 |
+
if self.no_split:
|
259 |
+
C = self.dim
|
260 |
+
C2 = C
|
261 |
+
|
262 |
+
fmaps = self.get_fmaps(images_, B, T, sw, is_training).reshape(B,T,C,H8,W8)
|
263 |
+
device = fmaps.device
|
264 |
+
print("a2 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
265 |
+
|
266 |
+
fmap_anchor = fmaps[:,0]
|
267 |
+
|
268 |
+
if T<=2 or is_training:
|
269 |
+
# note: collecting preds can get expensive on a long video
|
270 |
+
all_flow_preds = []
|
271 |
+
all_visconf_preds = []
|
272 |
+
else:
|
273 |
+
all_flow_preds = None
|
274 |
+
all_visconf_preds = None
|
275 |
+
|
276 |
+
if T > 2: # multiframe tracking
|
277 |
+
|
278 |
+
# we will store our final outputs in these tensors
|
279 |
+
full_flows = torch.zeros((B,T,2,H,W), dtype=dtype, device=device)
|
280 |
+
full_visconfs = torch.zeros((B,T,2,H,W), dtype=dtype, device=device)
|
281 |
+
# 1/8 resolution
|
282 |
+
full_flows8 = torch.zeros((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
|
283 |
+
full_visconfs8 = torch.zeros((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
|
284 |
+
|
285 |
+
if self.use_feats8:
|
286 |
+
full_feats8 = torch.zeros((B,T,C2,H_pad//8,W_pad//8), dtype=dtype, device=device)
|
287 |
+
visits = np.zeros((T))
|
288 |
+
print("a3 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
289 |
+
|
290 |
+
for ii, ind in enumerate(indices):
|
291 |
+
ara = np.arange(ind,ind+S)
|
292 |
+
print('ara', ara)
|
293 |
+
if ii < len(indices)-1:
|
294 |
+
next_ind = indices[ii+1]
|
295 |
+
next_ara = np.arange(next_ind,next_ind+S)
|
296 |
+
|
297 |
+
# print("torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024), 'ara', ara)
|
298 |
+
fmaps2 = fmaps[:,ara]
|
299 |
+
flows8 = full_flows8[:,ara].reshape(B*(S),2,H_pad//8,W_pad//8).detach()
|
300 |
+
visconfs8 = full_visconfs8[:,ara].reshape(B*(S),2,H_pad//8,W_pad//8).detach()
|
301 |
+
|
302 |
+
if self.use_feats8:
|
303 |
+
if ind==0:
|
304 |
+
feats8 = None
|
305 |
+
else:
|
306 |
+
feats8 = full_feats8[:,ara].reshape(B*(S),C2,H_pad//8,W_pad//8).detach()
|
307 |
+
else:
|
308 |
+
feats8 = None
|
309 |
+
print("a4 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
310 |
+
|
311 |
+
flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
|
312 |
+
fmap_anchor, fmaps2, visconfs8, iters=iters, flowfeat=feats8, flows8=flows8,
|
313 |
+
is_training=is_training)
|
314 |
+
print("a5 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
315 |
+
|
316 |
+
unpad_flow_predictions = []
|
317 |
+
unpad_visconf_predictions = []
|
318 |
+
for i in range(len(flow_predictions)):
|
319 |
+
flow_predictions[i] = padder.unpad(flow_predictions[i])
|
320 |
+
unpad_flow_predictions.append(flow_predictions[i].reshape(B,S,2,H,W))
|
321 |
+
visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
|
322 |
+
unpad_visconf_predictions.append(visconf_predictions[i].reshape(B,S,2,H,W))
|
323 |
+
print("a6 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
324 |
+
|
325 |
+
full_flows[:,ara] = unpad_flow_predictions[-1].reshape(B,S,2,H,W)
|
326 |
+
full_flows8[:,ara] = flows8.reshape(B,S,2,H_pad//8,W_pad//8)
|
327 |
+
full_visconfs[:,ara] = unpad_visconf_predictions[-1].reshape(B,S,2,H,W)
|
328 |
+
full_visconfs8[:,ara] = visconfs8.reshape(B,S,2,H_pad//8,W_pad//8)
|
329 |
+
if self.use_feats8:
|
330 |
+
full_feats8[:,ara] = feats8.reshape(B,S,C2,H_pad//8,W_pad//8)
|
331 |
+
visits[ara] += 1
|
332 |
+
print("a7 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
333 |
+
|
334 |
+
if is_training:
|
335 |
+
all_flow_preds.append(unpad_flow_predictions)
|
336 |
+
all_visconf_preds.append(unpad_visconf_predictions)
|
337 |
+
else:
|
338 |
+
del unpad_flow_predictions
|
339 |
+
del unpad_visconf_predictions
|
340 |
+
|
341 |
+
# for the next iter, replace empty data with nearest available preds
|
342 |
+
invalid_idx = np.where(visits==0)[0]
|
343 |
+
valid_idx = np.where(visits>0)[0]
|
344 |
+
for idx in invalid_idx:
|
345 |
+
nearest = valid_idx[np.argmin(np.abs(valid_idx - idx))]
|
346 |
+
# print('replacing %d with %d' % (idx, nearest))
|
347 |
+
full_flows8[:,idx] = full_flows8[:,nearest]
|
348 |
+
full_visconfs8[:,idx] = full_visconfs8[:,nearest]
|
349 |
+
if self.use_feats8:
|
350 |
+
full_feats8[:,idx] = full_feats8[:,nearest]
|
351 |
+
print("a8 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
352 |
+
else: # flow
|
353 |
+
|
354 |
+
flows8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
|
355 |
+
visconfs8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
|
356 |
+
|
357 |
+
flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
|
358 |
+
fmap_anchor, fmaps[:,1:2], visconfs8, iters=iters, flowfeat=None, flows8=flows8,
|
359 |
+
is_training=is_training)
|
360 |
+
unpad_flow_predictions = []
|
361 |
+
unpad_visconf_predictions = []
|
362 |
+
for i in range(len(flow_predictions)):
|
363 |
+
flow_predictions[i] = padder.unpad(flow_predictions[i])
|
364 |
+
all_flow_preds.append(flow_predictions[i].reshape(B,2,H,W))
|
365 |
+
visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
|
366 |
+
all_visconf_preds.append(visconf_predictions[i].reshape(B,2,H,W))
|
367 |
+
full_flows = all_flow_preds[-1].reshape(B,2,H,W)
|
368 |
+
full_visconfs = all_visconf_preds[-1].reshape(B,2,H,W)
|
369 |
+
|
370 |
+
if (not is_training) and (T > 2):
|
371 |
+
full_flows = full_flows[:,:T_bak]
|
372 |
+
full_visconfs = full_visconfs[:,:T_bak]
|
373 |
+
print("a9 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
374 |
+
|
375 |
+
return full_flows, full_visconfs, all_flow_preds, all_visconf_preds
|
376 |
+
|
377 |
+
def forward_sliding(self, images, iters=4, sw=None, is_training=False, window_len=None, stride=None):
|
378 |
+
B,T,C,H,W = images.shape
|
379 |
+
S = self.seqlen if window_len is None else window_len
|
380 |
+
device = images.device
|
381 |
+
dtype = images.dtype
|
382 |
+
stride = S // 2 if stride is None else stride
|
383 |
+
|
384 |
+
T_bak = T
|
385 |
+
images, T, indices = self.get_T_padded_images(images, T, S, is_training, stride)
|
386 |
+
assert stride <= S // 2
|
387 |
+
|
388 |
+
images = images.contiguous()
|
389 |
+
images_ = images.reshape(B*T,3,H,W)
|
390 |
+
padder = InputPadder(images_.shape)
|
391 |
+
images_ = padder.pad(images_)[0]
|
392 |
+
|
393 |
+
_, _, H_pad, W_pad = images_.shape # revised HW
|
394 |
+
C, H8, W8 = self.dim*2, H_pad//8, W_pad//8
|
395 |
+
C2 = C//2
|
396 |
+
if self.no_split:
|
397 |
+
C = self.dim
|
398 |
+
C2 = C
|
399 |
+
|
400 |
+
all_flow_preds = None
|
401 |
+
all_visconf_preds = None
|
402 |
+
|
403 |
+
if T<=2:
|
404 |
+
# note: collecting preds can get expensive on a long video
|
405 |
+
all_flow_preds = []
|
406 |
+
all_visconf_preds = []
|
407 |
+
|
408 |
+
fmaps = self.get_fmaps(images_, B, T, sw, is_training).reshape(B,T,C,H8,W8)
|
409 |
+
device = fmaps.device
|
410 |
+
|
411 |
+
flows8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
|
412 |
+
visconfs8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
|
413 |
+
|
414 |
+
fmap_anchor = fmaps[:,0]
|
415 |
+
|
416 |
+
flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
|
417 |
+
fmap_anchor, fmaps[:,1:2], visconfs8, iters=iters, flowfeat=None, flows8=flows8,
|
418 |
+
is_training=is_training)
|
419 |
+
unpad_flow_predictions = []
|
420 |
+
unpad_visconf_predictions = []
|
421 |
+
for i in range(len(flow_predictions)):
|
422 |
+
flow_predictions[i] = padder.unpad(flow_predictions[i])
|
423 |
+
all_flow_preds.append(flow_predictions[i].reshape(B,2,H,W))
|
424 |
+
visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
|
425 |
+
all_visconf_preds.append(visconf_predictions[i].reshape(B,2,H,W))
|
426 |
+
full_flows = all_flow_preds[-1].reshape(B,2,H,W).detach().cpu()
|
427 |
+
full_visconfs = all_visconf_preds[-1].reshape(B,2,H,W).detach().cpu()
|
428 |
+
|
429 |
+
return full_flows, full_visconfs, all_flow_preds, all_visconf_preds
|
430 |
+
|
431 |
+
assert T > 2 # multiframe tracking
|
432 |
+
|
433 |
+
if is_training:
|
434 |
+
all_flow_preds = []
|
435 |
+
all_visconf_preds = []
|
436 |
+
|
437 |
+
# we will store our final outputs in these cpu tensors
|
438 |
+
full_flows = torch.zeros((B,T,2,H,W), dtype=dtype, device='cpu')
|
439 |
+
full_visconfs = torch.zeros((B,T,2,H,W), dtype=dtype, device='cpu')
|
440 |
+
|
441 |
+
images_ = images_.reshape(B,T,3,H_pad,W_pad)
|
442 |
+
fmap_anchor = self.get_fmaps(images_[:,:1].reshape(-1,3,H_pad,W_pad), B, 1, sw, is_training).reshape(B,C,H8,W8)
|
443 |
+
device = fmap_anchor.device
|
444 |
+
full_visited = torch.zeros((T,), dtype=torch.bool, device=device)
|
445 |
+
|
446 |
+
for ii, ind in enumerate(indices):
|
447 |
+
ara = np.arange(ind,ind+S)
|
448 |
+
if ii == 0:
|
449 |
+
flows8 = torch.zeros((B,S,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
|
450 |
+
visconfs8 = torch.zeros((B,S,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
|
451 |
+
fmaps2 = self.get_fmaps(images_[:,ara].reshape(-1,3,H_pad,W_pad), B, S, sw, is_training).reshape(B,S,C,H8,W8)
|
452 |
+
else:
|
453 |
+
flows8 = torch.cat([flows8[:,stride:stride+S//2], flows8[:,stride+S//2-1:stride+S//2].repeat(1,S//2,1,1,1)], dim=1)
|
454 |
+
visconfs8 = torch.cat([visconfs8[:,stride:stride+S//2], visconfs8[:,stride+S//2-1:stride+S//2].repeat(1,S//2,1,1,1)], dim=1)
|
455 |
+
fmaps2 = torch.cat([fmaps2[:,stride:stride+S//2],
|
456 |
+
self.get_fmaps(images_[:,np.arange(ind+S//2,ind+S)].reshape(-1,3,H_pad,W_pad), B, S//2, sw, is_training).reshape(B,S//2,C,H8,W8)], dim=1)
|
457 |
+
|
458 |
+
flows8 = flows8.reshape(B*S,2,H_pad//8,W_pad//8).detach()
|
459 |
+
visconfs8 = visconfs8.reshape(B*S,2,H_pad//8,W_pad//8).detach()
|
460 |
+
|
461 |
+
flow_predictions, visconf_predictions, flows8, visconfs8, _ = self.forward_window(
|
462 |
+
fmap_anchor, fmaps2, visconfs8, iters=iters, flowfeat=None, flows8=flows8,
|
463 |
+
is_training=is_training)
|
464 |
+
|
465 |
+
unpad_flow_predictions = []
|
466 |
+
unpad_visconf_predictions = []
|
467 |
+
for i in range(len(flow_predictions)):
|
468 |
+
flow_predictions[i] = padder.unpad(flow_predictions[i])
|
469 |
+
unpad_flow_predictions.append(flow_predictions[i].reshape(B,S,2,H,W))
|
470 |
+
visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
|
471 |
+
unpad_visconf_predictions.append(visconf_predictions[i].reshape(B,S,2,H,W))
|
472 |
+
|
473 |
+
current_visiting = torch.zeros((T,), dtype=torch.bool, device=device)
|
474 |
+
current_visiting[ara] = True
|
475 |
+
|
476 |
+
to_fill = current_visiting & (~full_visited)
|
477 |
+
to_fill_sum = to_fill.sum().item()
|
478 |
+
full_flows[:,to_fill] = unpad_flow_predictions[-1].reshape(B,S,2,H,W)[:,-to_fill_sum:].detach().cpu()
|
479 |
+
full_visconfs[:,to_fill] = unpad_visconf_predictions[-1].reshape(B,S,2,H,W)[:,-to_fill_sum:].detach().cpu()
|
480 |
+
full_visited |= current_visiting
|
481 |
+
|
482 |
+
if is_training:
|
483 |
+
all_flow_preds.append(unpad_flow_predictions)
|
484 |
+
all_visconf_preds.append(unpad_visconf_predictions)
|
485 |
+
else:
|
486 |
+
del unpad_flow_predictions
|
487 |
+
del unpad_visconf_predictions
|
488 |
+
|
489 |
+
flows8 = flows8.reshape(B,S,2,H_pad//8,W_pad//8)
|
490 |
+
visconfs8 = visconfs8.reshape(B,S,2,H_pad//8,W_pad//8)
|
491 |
+
|
492 |
+
if not is_training:
|
493 |
+
full_flows = full_flows[:,:T_bak]
|
494 |
+
full_visconfs = full_visconfs[:,:T_bak]
|
495 |
+
|
496 |
+
return full_flows, full_visconfs, all_flow_preds, all_visconf_preds
|
497 |
+
|
498 |
+
def forward_window(self, fmap1_single, fmaps2, visconfs8, iters=None, flowfeat=None, flows8=None, sw=None, is_training=False):
|
499 |
+
B,S,C,H8,W8 = fmaps2.shape
|
500 |
+
device = fmaps2.device
|
501 |
+
dtype = fmaps2.dtype
|
502 |
+
|
503 |
+
flow_predictions = []
|
504 |
+
visconf_predictions = []
|
505 |
+
|
506 |
+
fmap1 = fmap1_single.unsqueeze(1).repeat(1,S,1,1,1) # B,S,C,H,W
|
507 |
+
fmap1 = fmap1.reshape(B*(S),C,H8,W8).contiguous()
|
508 |
+
|
509 |
+
fmap2 = fmaps2.reshape(B*(S),C,H8,W8).contiguous()
|
510 |
+
|
511 |
+
visconfs8 = visconfs8.reshape(B*(S),2,H8,W8).contiguous()
|
512 |
+
|
513 |
+
corr_fn = CorrBlock(fmap1, fmap2, self.corr_levels, self.corr_radius)
|
514 |
+
|
515 |
+
coords1 = self.coords_grid(B*(S), H8, W8, device=fmap1.device, dtype=dtype)
|
516 |
+
|
517 |
+
if self.no_split:
|
518 |
+
flowfeat, ctxfeat = fmap1.clone(), fmap1.clone()
|
519 |
+
else:
|
520 |
+
if flowfeat is not None:
|
521 |
+
_, ctxfeat = torch.split(fmap1, [self.dim, self.dim], dim=1)
|
522 |
+
else:
|
523 |
+
flowfeat, ctxfeat = torch.split(fmap1, [self.dim, self.dim], dim=1)
|
524 |
+
|
525 |
+
# add pos emb to ctxfeat (and not flowfeat), since ctxfeat is untouched across iters
|
526 |
+
time_emb = self.fetch_time_embed(S, ctxfeat.dtype, is_training).reshape(1,S,self.dim,1,1).repeat(B,1,1,1,1)
|
527 |
+
ctxfeat = ctxfeat + time_emb.reshape(B*S,self.dim,1,1)
|
528 |
+
|
529 |
+
if self.no_ctx:
|
530 |
+
flowfeat = flowfeat + time_emb.reshape(B*S,self.dim,1,1)
|
531 |
+
|
532 |
+
for itr in range(iters):
|
533 |
+
_, _, H8, W8 = flows8.shape
|
534 |
+
flows8 = flows8.detach()
|
535 |
+
coords2 = (coords1 + flows8).detach() # B*S,2,H,W
|
536 |
+
corr = corr_fn(coords2).to(dtype)
|
537 |
+
|
538 |
+
if self.use_relmotion or self.use_sinrelmotion:
|
539 |
+
coords_ = coords2.reshape(B,S,2,H8*W8).permute(0,1,3,2) # B,S,H8*W8,2
|
540 |
+
rel_coords_forward = coords_[:, :-1] - coords_[:, 1:]
|
541 |
+
rel_coords_backward = coords_[:, 1:] - coords_[:, :-1]
|
542 |
+
rel_coords_forward = torch.nn.functional.pad(
|
543 |
+
rel_coords_forward, (0, 0, 0, 0, 0, 1) # pad the 3rd-last dim (S) by (0,1)
|
544 |
+
)
|
545 |
+
rel_coords_backward = torch.nn.functional.pad(
|
546 |
+
rel_coords_backward, (0, 0, 0, 0, 1, 0) # pad the 3rd-last dim (S) by (1,0)
|
547 |
+
)
|
548 |
+
rel_coords = torch.cat([rel_coords_forward, rel_coords_backward], dim=-1) # B,S,H8*W8,4
|
549 |
+
|
550 |
+
if self.use_sinrelmotion:
|
551 |
+
rel_pos_emb_input = utils.misc.posenc(
|
552 |
+
rel_coords,
|
553 |
+
min_deg=0,
|
554 |
+
max_deg=10,
|
555 |
+
) # B,S,H*W,pdim
|
556 |
+
motion = rel_pos_emb_input.reshape(B*S,H8,W8,self.pdim).permute(0,3,1,2).to(dtype) # B*S,pdim,H8,W8
|
557 |
+
else:
|
558 |
+
motion = rel_coords.reshape(B*S,H8,W8,4).permute(0,3,1,2).to(dtype) # B*S,4,H8,W8
|
559 |
+
|
560 |
+
else:
|
561 |
+
if self.use_sinmotion:
|
562 |
+
pos_emb_input = utils.misc.posenc(
|
563 |
+
flows8.reshape(B,S,H8*W8,2),
|
564 |
+
min_deg=0,
|
565 |
+
max_deg=10,
|
566 |
+
) # B,S,H*W,pdim
|
567 |
+
motion = pos_emb_input.reshape(B*S,H8,W8,self.pdim).permute(0,3,1,2).to(dtype) # B*S,pdim,H8,W8
|
568 |
+
else:
|
569 |
+
motion = flows8
|
570 |
+
|
571 |
+
flowfeat = self.update_block(flowfeat, ctxfeat, visconfs8, corr, motion, S)
|
572 |
+
flow_update = self.flow_head(flowfeat)
|
573 |
+
visconf_update = self.visconf_head(flowfeat)
|
574 |
+
weight_update = .25 * self.upsample_weight(flowfeat)
|
575 |
+
flows8 = flows8 + flow_update
|
576 |
+
visconfs8 = visconfs8 + visconf_update
|
577 |
+
flow_up = self.upsample_data(flows8, weight_update)
|
578 |
+
visconf_up = self.upsample_data(visconfs8, weight_update)
|
579 |
+
if not is_training: # clear mem
|
580 |
+
flow_predictions = []
|
581 |
+
visconf_predictions = []
|
582 |
+
flow_predictions.append(flow_up)
|
583 |
+
visconf_predictions.append(visconf_up)
|
584 |
+
|
585 |
+
return flow_predictions, visconf_predictions, flows8, visconfs8, flowfeat
|
586 |
+
|
587 |
+
|
588 |
+
|
nets/blocks.py
ADDED
@@ -0,0 +1,1304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch import nn, Tensor
|
5 |
+
from itertools import repeat
|
6 |
+
import collections
|
7 |
+
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
|
8 |
+
from functools import partial
|
9 |
+
import einops
|
10 |
+
import math
|
11 |
+
from torchvision.ops.misc import Conv2dNormActivation, Permute
|
12 |
+
from torchvision.ops.stochastic_depth import StochasticDepth
|
13 |
+
|
14 |
+
def _ntuple(n):
|
15 |
+
def parse(x):
|
16 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
17 |
+
return tuple(x)
|
18 |
+
return tuple(repeat(x, n))
|
19 |
+
return parse
|
20 |
+
|
21 |
+
def exists(val):
|
22 |
+
return val is not None
|
23 |
+
|
24 |
+
def default(val, d):
|
25 |
+
return val if exists(val) else d
|
26 |
+
|
27 |
+
to_2tuple = _ntuple(2)
|
28 |
+
|
29 |
+
class InputPadder:
|
30 |
+
""" Pads images such that dimensions are divisible by a certain stride """
|
31 |
+
def __init__(self, dims, mode='sintel'):
|
32 |
+
self.ht, self.wd = dims[-2:]
|
33 |
+
pad_ht = (((self.ht // 64) + 1) * 64 - self.ht) % 64
|
34 |
+
pad_wd = (((self.wd // 64) + 1) * 64 - self.wd) % 64
|
35 |
+
if mode == 'sintel':
|
36 |
+
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
|
37 |
+
else:
|
38 |
+
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
|
39 |
+
|
40 |
+
def pad(self, *inputs):
|
41 |
+
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
|
42 |
+
|
43 |
+
def unpad(self, x):
|
44 |
+
ht, wd = x.shape[-2:]
|
45 |
+
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
|
46 |
+
return x[..., c[0]:c[1], c[2]:c[3]]
|
47 |
+
|
48 |
+
def bilinear_sampler(
|
49 |
+
input, coords,
|
50 |
+
align_corners=True,
|
51 |
+
padding_mode="border",
|
52 |
+
normalize_coords=True):
|
53 |
+
# func from mattie (oct9)
|
54 |
+
if input.ndim not in [4, 5]:
|
55 |
+
raise ValueError("input must be 4D or 5D.")
|
56 |
+
|
57 |
+
if input.ndim == 4 and not coords.ndim == 4:
|
58 |
+
raise ValueError("input is 4D, but coords is not 4D.")
|
59 |
+
|
60 |
+
if input.ndim == 5 and not coords.ndim == 5:
|
61 |
+
raise ValueError("input is 5D, but coords is not 5D.")
|
62 |
+
|
63 |
+
if coords.ndim == 5:
|
64 |
+
coords = coords[..., [1, 2, 0]] # t x y -> x y t to match what grid_sample() expects.
|
65 |
+
|
66 |
+
if normalize_coords:
|
67 |
+
if align_corners:
|
68 |
+
# Normalize coordinates from [0, W/H - 1] to [-1, 1].
|
69 |
+
coords = (
|
70 |
+
coords
|
71 |
+
* torch.tensor([2 / max(size - 1, 1) for size in reversed(input.shape[2:])], device=coords.device)
|
72 |
+
- 1
|
73 |
+
)
|
74 |
+
else:
|
75 |
+
# Normalize coordinates from [0, W/H] to [-1, 1].
|
76 |
+
coords = coords * torch.tensor([2 / size for size in reversed(input.shape[2:])], device=coords.device) - 1
|
77 |
+
|
78 |
+
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
|
79 |
+
|
80 |
+
|
81 |
+
class CorrBlock:
|
82 |
+
def __init__(self, fmap1, fmap2, corr_levels, corr_radius):
|
83 |
+
self.num_levels = corr_levels
|
84 |
+
self.radius = corr_radius
|
85 |
+
self.corr_pyramid = []
|
86 |
+
# all pairs correlation
|
87 |
+
for i in range(self.num_levels):
|
88 |
+
corr = CorrBlock.corr(fmap1, fmap2, 1)
|
89 |
+
batch, h1, w1, dim, h2, w2 = corr.shape
|
90 |
+
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
|
91 |
+
fmap2 = F.interpolate(fmap2, scale_factor=0.5, mode='area')
|
92 |
+
# print('corr', corr.shape)
|
93 |
+
self.corr_pyramid.append(corr)
|
94 |
+
|
95 |
+
def __call__(self, coords, dilation=None):
|
96 |
+
r = self.radius
|
97 |
+
coords = coords.permute(0, 2, 3, 1)
|
98 |
+
batch, h1, w1, _ = coords.shape
|
99 |
+
|
100 |
+
if dilation is None:
|
101 |
+
dilation = torch.ones(batch, 1, h1, w1, device=coords.device)
|
102 |
+
|
103 |
+
out_pyramid = []
|
104 |
+
for i in range(self.num_levels):
|
105 |
+
corr = self.corr_pyramid[i]
|
106 |
+
device = coords.device
|
107 |
+
dx = torch.linspace(-r, r, 2*r+1, device=device)
|
108 |
+
dy = torch.linspace(-r, r, 2*r+1, device=device)
|
109 |
+
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
|
110 |
+
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
|
111 |
+
delta_lvl = delta_lvl * dilation.view(batch * h1 * w1, 1, 1, 1)
|
112 |
+
centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
|
113 |
+
coords_lvl = centroid_lvl + delta_lvl
|
114 |
+
corr = bilinear_sampler(corr, coords_lvl)
|
115 |
+
corr = corr.view(batch, h1, w1, -1)
|
116 |
+
out_pyramid.append(corr)
|
117 |
+
|
118 |
+
out = torch.cat(out_pyramid, dim=-1)
|
119 |
+
out = out.permute(0, 3, 1, 2).contiguous().float()
|
120 |
+
return out
|
121 |
+
|
122 |
+
@staticmethod
|
123 |
+
def corr(fmap1, fmap2, num_head):
|
124 |
+
batch, dim, h1, w1 = fmap1.shape
|
125 |
+
h2, w2 = fmap2.shape[2:]
|
126 |
+
fmap1 = fmap1.view(batch, num_head, dim // num_head, h1*w1)
|
127 |
+
fmap2 = fmap2.view(batch, num_head, dim // num_head, h2*w2)
|
128 |
+
corr = fmap1.transpose(2, 3) @ fmap2
|
129 |
+
corr = corr.reshape(batch, num_head, h1, w1, h2, w2).permute(0, 2, 3, 1, 4, 5)
|
130 |
+
return corr / torch.sqrt(torch.tensor(dim).float())
|
131 |
+
|
132 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
133 |
+
"""1x1 convolution without padding"""
|
134 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0)
|
135 |
+
|
136 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
137 |
+
"""3x3 convolution with padding"""
|
138 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
|
139 |
+
|
140 |
+
class LayerNorm2d(nn.LayerNorm):
|
141 |
+
def forward(self, x: Tensor) -> Tensor:
|
142 |
+
x = x.permute(0, 2, 3, 1)
|
143 |
+
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
144 |
+
x = x.permute(0, 3, 1, 2)
|
145 |
+
return x
|
146 |
+
|
147 |
+
class CNBlock1d(nn.Module):
|
148 |
+
def __init__(
|
149 |
+
self,
|
150 |
+
dim,
|
151 |
+
output_dim,
|
152 |
+
layer_scale: float = 1e-6,
|
153 |
+
stochastic_depth_prob: float = 0,
|
154 |
+
norm_layer: Optional[Callable[..., nn.Module]] = None,
|
155 |
+
dense=True,
|
156 |
+
use_attn=True,
|
157 |
+
use_mixer=False,
|
158 |
+
use_conv=False,
|
159 |
+
use_convb=False,
|
160 |
+
use_layer_scale=True,
|
161 |
+
) -> None:
|
162 |
+
super().__init__()
|
163 |
+
self.dense = dense
|
164 |
+
self.use_attn = use_attn
|
165 |
+
self.use_mixer = use_mixer
|
166 |
+
self.use_conv = use_conv
|
167 |
+
self.use_layer_scale = use_layer_scale
|
168 |
+
|
169 |
+
if use_attn:
|
170 |
+
assert not use_mixer
|
171 |
+
assert not use_conv
|
172 |
+
assert not use_convb
|
173 |
+
|
174 |
+
if norm_layer is None:
|
175 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
176 |
+
|
177 |
+
if use_attn:
|
178 |
+
num_heads = 8
|
179 |
+
self.block = AttnBlock(
|
180 |
+
hidden_size=dim,
|
181 |
+
num_heads=num_heads,
|
182 |
+
mlp_ratio=4,
|
183 |
+
attn_class=Attention,
|
184 |
+
)
|
185 |
+
elif use_mixer:
|
186 |
+
self.block = MLPMixerBlock(
|
187 |
+
S=16,
|
188 |
+
dim=dim,
|
189 |
+
depth=1,
|
190 |
+
expansion_factor=2,
|
191 |
+
)
|
192 |
+
elif use_conv:
|
193 |
+
self.block = nn.Sequential(
|
194 |
+
nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True, padding_mode='zeros'),
|
195 |
+
Permute([0, 2, 1]),
|
196 |
+
norm_layer(dim),
|
197 |
+
nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
|
198 |
+
nn.GELU(),
|
199 |
+
nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
|
200 |
+
Permute([0, 2, 1]),
|
201 |
+
)
|
202 |
+
elif use_convb:
|
203 |
+
self.block = nn.Sequential(
|
204 |
+
nn.Conv1d(dim, dim, kernel_size=3, padding=1, bias=True, padding_mode='zeros'),
|
205 |
+
Permute([0, 2, 1]),
|
206 |
+
norm_layer(dim),
|
207 |
+
nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
|
208 |
+
nn.GELU(),
|
209 |
+
nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
|
210 |
+
Permute([0, 2, 1]),
|
211 |
+
)
|
212 |
+
else:
|
213 |
+
assert(False) # choose attn, mixer, or conv please
|
214 |
+
|
215 |
+
if self.use_layer_scale:
|
216 |
+
self.layer_scale = nn.Parameter(torch.ones(dim, 1) * layer_scale)
|
217 |
+
else:
|
218 |
+
self.layer_scale = 1.0
|
219 |
+
|
220 |
+
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
|
221 |
+
|
222 |
+
if output_dim != dim:
|
223 |
+
self.final = nn.Conv1d(dim, output_dim, kernel_size=1, padding=0)
|
224 |
+
else:
|
225 |
+
self.final = nn.Identity()
|
226 |
+
|
227 |
+
def forward(self, input, S=None):
|
228 |
+
if self.dense:
|
229 |
+
assert S is not None
|
230 |
+
BS,C,H,W = input.shape
|
231 |
+
B = BS//S
|
232 |
+
|
233 |
+
input = einops.rearrange(input, '(b s) c h w -> (b h w) c s', b=B, s=S, c=C, h=H, w=W)
|
234 |
+
|
235 |
+
if self.use_mixer or self.use_attn:
|
236 |
+
# mixer/transformer blocks want B,S,C
|
237 |
+
result = self.layer_scale * self.block(input.permute(0,2,1)).permute(0,2,1)
|
238 |
+
else:
|
239 |
+
result = self.layer_scale * self.block(input)
|
240 |
+
result = self.stochastic_depth(result)
|
241 |
+
result += input
|
242 |
+
result = self.final(result)
|
243 |
+
|
244 |
+
result = einops.rearrange(result, '(b h w) c s -> (b s) c h w', b=B, s=S, c=C, h=H, w=W)
|
245 |
+
else:
|
246 |
+
B,S,C = input.shape
|
247 |
+
|
248 |
+
if S<7:
|
249 |
+
return input
|
250 |
+
|
251 |
+
input = einops.rearrange(input, 'b s c -> b c s', b=B, s=S, c=C)
|
252 |
+
|
253 |
+
result = self.layer_scale * self.block(input)
|
254 |
+
result = self.stochastic_depth(result)
|
255 |
+
result += input
|
256 |
+
|
257 |
+
result = self.final(result)
|
258 |
+
|
259 |
+
result = einops.rearrange(result, 'b c s -> b s c', b=B, s=S, c=C)
|
260 |
+
|
261 |
+
return result
|
262 |
+
|
263 |
+
class CNBlock2d(nn.Module):
|
264 |
+
def __init__(
|
265 |
+
self,
|
266 |
+
dim,
|
267 |
+
output_dim,
|
268 |
+
layer_scale: float = 1e-6,
|
269 |
+
stochastic_depth_prob: float = 0,
|
270 |
+
norm_layer: Optional[Callable[..., nn.Module]] = None,
|
271 |
+
use_layer_scale=True,
|
272 |
+
) -> None:
|
273 |
+
super().__init__()
|
274 |
+
self.use_layer_scale = use_layer_scale
|
275 |
+
if norm_layer is None:
|
276 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
277 |
+
|
278 |
+
self.block = nn.Sequential(
|
279 |
+
nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True, padding_mode='zeros'),
|
280 |
+
Permute([0, 2, 3, 1]),
|
281 |
+
norm_layer(dim),
|
282 |
+
nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
|
283 |
+
nn.GELU(),
|
284 |
+
nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
|
285 |
+
Permute([0, 3, 1, 2]),
|
286 |
+
)
|
287 |
+
if self.use_layer_scale:
|
288 |
+
self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
|
289 |
+
else:
|
290 |
+
self.layer_scale = 1.0
|
291 |
+
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
|
292 |
+
|
293 |
+
if output_dim != dim:
|
294 |
+
self.final = nn.Conv2d(dim, output_dim, kernel_size=1, padding=0)
|
295 |
+
else:
|
296 |
+
self.final = nn.Identity()
|
297 |
+
|
298 |
+
def forward(self, input, S=None):
|
299 |
+
result = self.layer_scale * self.block(input)
|
300 |
+
result = self.stochastic_depth(result)
|
301 |
+
result += input
|
302 |
+
result = self.final(result)
|
303 |
+
return result
|
304 |
+
|
305 |
+
class CNBlockConfig:
|
306 |
+
# Stores information listed at Section 3 of the ConvNeXt paper
|
307 |
+
def __init__(
|
308 |
+
self,
|
309 |
+
input_channels: int,
|
310 |
+
out_channels: Optional[int],
|
311 |
+
num_layers: int,
|
312 |
+
downsample: bool,
|
313 |
+
) -> None:
|
314 |
+
self.input_channels = input_channels
|
315 |
+
self.out_channels = out_channels
|
316 |
+
self.num_layers = num_layers
|
317 |
+
self.downsample = downsample
|
318 |
+
|
319 |
+
def __repr__(self) -> str:
|
320 |
+
s = self.__class__.__name__ + "("
|
321 |
+
s += "input_channels={input_channels}"
|
322 |
+
s += ", out_channels={out_channels}"
|
323 |
+
s += ", num_layers={num_layers}"
|
324 |
+
s += ", downsample={downsample}"
|
325 |
+
s += ")"
|
326 |
+
return s.format(**self.__dict__)
|
327 |
+
|
328 |
+
class ConvNeXt(nn.Module):
|
329 |
+
def __init__(
|
330 |
+
self,
|
331 |
+
block_setting: List[CNBlockConfig],
|
332 |
+
stochastic_depth_prob: float = 0.0,
|
333 |
+
layer_scale: float = 1e-6,
|
334 |
+
num_classes: int = 1000,
|
335 |
+
block: Optional[Callable[..., nn.Module]] = None,
|
336 |
+
norm_layer: Optional[Callable[..., nn.Module]] = None,
|
337 |
+
init_weights=True):
|
338 |
+
super().__init__()
|
339 |
+
|
340 |
+
self.init_weights = init_weights
|
341 |
+
|
342 |
+
if not block_setting:
|
343 |
+
raise ValueError("The block_setting should not be empty")
|
344 |
+
elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
|
345 |
+
raise TypeError("The block_setting should be List[CNBlockConfig]")
|
346 |
+
|
347 |
+
if block is None:
|
348 |
+
block = CNBlock2d
|
349 |
+
|
350 |
+
if norm_layer is None:
|
351 |
+
norm_layer = partial(LayerNorm2d, eps=1e-6)
|
352 |
+
|
353 |
+
layers: List[nn.Module] = []
|
354 |
+
|
355 |
+
# Stem
|
356 |
+
firstconv_output_channels = block_setting[0].input_channels
|
357 |
+
layers.append(
|
358 |
+
Conv2dNormActivation(
|
359 |
+
3,
|
360 |
+
firstconv_output_channels,
|
361 |
+
kernel_size=4,
|
362 |
+
stride=4,
|
363 |
+
padding=0,
|
364 |
+
norm_layer=norm_layer,
|
365 |
+
activation_layer=None,
|
366 |
+
bias=True,
|
367 |
+
)
|
368 |
+
)
|
369 |
+
|
370 |
+
total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
|
371 |
+
stage_block_id = 0
|
372 |
+
for cnf in block_setting:
|
373 |
+
# Bottlenecks
|
374 |
+
stage: List[nn.Module] = []
|
375 |
+
for _ in range(cnf.num_layers):
|
376 |
+
# adjust stochastic depth probability based on the depth of the stage block
|
377 |
+
sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
|
378 |
+
stage.append(block(cnf.input_channels, cnf.input_channels, layer_scale, sd_prob))
|
379 |
+
stage_block_id += 1
|
380 |
+
layers.append(nn.Sequential(*stage))
|
381 |
+
if cnf.out_channels is not None:
|
382 |
+
if cnf.downsample:
|
383 |
+
layers.append(
|
384 |
+
nn.Sequential(
|
385 |
+
norm_layer(cnf.input_channels),
|
386 |
+
nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
|
387 |
+
)
|
388 |
+
)
|
389 |
+
else:
|
390 |
+
# we convert the 2x2 downsampling layer into a 3x3 with dilation2 and replicate padding.
|
391 |
+
# replicate padding compensates for the fact that this kernel never saw zero-padding.
|
392 |
+
layers.append(
|
393 |
+
nn.Sequential(
|
394 |
+
norm_layer(cnf.input_channels),
|
395 |
+
nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=3, stride=1, padding=2, dilation=2, padding_mode='zeros'),
|
396 |
+
)
|
397 |
+
)
|
398 |
+
|
399 |
+
self.features = nn.Sequential(*layers)
|
400 |
+
|
401 |
+
# self.final_conv = conv1x1(block_setting[-1].input_channels, output_dim)
|
402 |
+
|
403 |
+
for m in self.modules():
|
404 |
+
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
405 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
406 |
+
if m.bias is not None:
|
407 |
+
nn.init.zeros_(m.bias)
|
408 |
+
|
409 |
+
if self.init_weights:
|
410 |
+
from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights
|
411 |
+
pretrained_dict = convnext_tiny(weights=ConvNeXt_Tiny_Weights.DEFAULT).state_dict()
|
412 |
+
# from torchvision.models import convnext_base, ConvNeXt_Base_Weights
|
413 |
+
# pretrained_dict = convnext_base(weights=ConvNeXt_Base_Weights.DEFAULT).state_dict()
|
414 |
+
model_dict = self.state_dict()
|
415 |
+
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
416 |
+
|
417 |
+
for k, v in pretrained_dict.items():
|
418 |
+
if k == 'features.4.1.weight': # this is the layer normally in charge of 2x2 downsampling
|
419 |
+
# convert to 3x3 filter
|
420 |
+
pretrained_dict[k] = F.interpolate(v, (3, 3), mode='bicubic', align_corners=True) * (4/9.0)
|
421 |
+
|
422 |
+
model_dict.update(pretrained_dict)
|
423 |
+
self.load_state_dict(model_dict, strict=False)
|
424 |
+
|
425 |
+
|
426 |
+
def _forward_impl(self, x: Tensor) -> Tensor:
|
427 |
+
x = self.features(x)
|
428 |
+
# x = self.final_conv(x)
|
429 |
+
return x
|
430 |
+
|
431 |
+
def forward(self, x: Tensor) -> Tensor:
|
432 |
+
return self._forward_impl(x)
|
433 |
+
|
434 |
+
class Mlp(nn.Module):
|
435 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
436 |
+
|
437 |
+
def __init__(
|
438 |
+
self,
|
439 |
+
in_features,
|
440 |
+
hidden_features=None,
|
441 |
+
out_features=None,
|
442 |
+
act_layer=nn.GELU,
|
443 |
+
norm_layer=None,
|
444 |
+
bias=True,
|
445 |
+
drop=0.0,
|
446 |
+
use_conv=False,
|
447 |
+
):
|
448 |
+
super().__init__()
|
449 |
+
out_features = out_features or in_features
|
450 |
+
hidden_features = hidden_features or in_features
|
451 |
+
bias = to_2tuple(bias)
|
452 |
+
drop_probs = to_2tuple(drop)
|
453 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
454 |
+
|
455 |
+
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
456 |
+
self.act = act_layer()
|
457 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
458 |
+
self.norm = (
|
459 |
+
norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
460 |
+
)
|
461 |
+
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
462 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
463 |
+
|
464 |
+
def forward(self, x):
|
465 |
+
x = self.fc1(x)
|
466 |
+
x = self.act(x)
|
467 |
+
x = self.drop1(x)
|
468 |
+
x = self.fc2(x)
|
469 |
+
x = self.drop2(x)
|
470 |
+
return x
|
471 |
+
|
472 |
+
class Attention(nn.Module):
|
473 |
+
def __init__(
|
474 |
+
self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False
|
475 |
+
):
|
476 |
+
super().__init__()
|
477 |
+
inner_dim = dim_head * num_heads
|
478 |
+
context_dim = default(context_dim, query_dim)
|
479 |
+
self.scale = dim_head**-0.5
|
480 |
+
self.heads = num_heads
|
481 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
|
482 |
+
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
|
483 |
+
self.to_out = nn.Linear(inner_dim, query_dim)
|
484 |
+
|
485 |
+
def forward(self, x, context=None, attn_bias=None):
|
486 |
+
B, N1, C = x.shape
|
487 |
+
H = self.heads
|
488 |
+
q = self.to_q(x)
|
489 |
+
context = default(context, x)
|
490 |
+
k, v = self.to_kv(context).chunk(2, dim=-1)
|
491 |
+
q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> b h n d', h=self.heads), (q, k, v))
|
492 |
+
x = F.scaled_dot_product_attention(q, k, v) # scale default is already dim^-0.5
|
493 |
+
x = einops.rearrange(x, 'b h n d -> b n (h d)')
|
494 |
+
return self.to_out(x)
|
495 |
+
|
496 |
+
class CrossAttnBlock(nn.Module):
|
497 |
+
def __init__(
|
498 |
+
self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs
|
499 |
+
):
|
500 |
+
super().__init__()
|
501 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
502 |
+
self.norm_context = nn.LayerNorm(hidden_size)
|
503 |
+
self.cross_attn = Attention(
|
504 |
+
hidden_size,
|
505 |
+
context_dim=context_dim,
|
506 |
+
num_heads=num_heads,
|
507 |
+
qkv_bias=True,
|
508 |
+
**block_kwargs
|
509 |
+
)
|
510 |
+
|
511 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
512 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
513 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
514 |
+
self.mlp = Mlp(
|
515 |
+
in_features=hidden_size,
|
516 |
+
hidden_features=mlp_hidden_dim,
|
517 |
+
act_layer=approx_gelu,
|
518 |
+
drop=0,
|
519 |
+
)
|
520 |
+
|
521 |
+
def forward(self, x, context, mask=None):
|
522 |
+
attn_bias = None
|
523 |
+
if mask is not None:
|
524 |
+
if mask.shape[1] == x.shape[1]:
|
525 |
+
mask = mask[:, None, :, None].expand(
|
526 |
+
-1, self.cross_attn.heads, -1, context.shape[1]
|
527 |
+
)
|
528 |
+
else:
|
529 |
+
mask = mask[:, None, None].expand(
|
530 |
+
-1, self.cross_attn.heads, x.shape[1], -1
|
531 |
+
)
|
532 |
+
|
533 |
+
max_neg_value = -torch.finfo(x.dtype).max
|
534 |
+
attn_bias = (~mask) * max_neg_value
|
535 |
+
x = x + self.cross_attn(
|
536 |
+
self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
|
537 |
+
)
|
538 |
+
x = x + self.mlp(self.norm2(x))
|
539 |
+
return x
|
540 |
+
|
541 |
+
class AttnBlock(nn.Module):
|
542 |
+
def __init__(
|
543 |
+
self,
|
544 |
+
hidden_size,
|
545 |
+
num_heads,
|
546 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
547 |
+
mlp_ratio=4.0,
|
548 |
+
**block_kwargs
|
549 |
+
):
|
550 |
+
super().__init__()
|
551 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
552 |
+
self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, dim_head=hidden_size//num_heads)
|
553 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
554 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
555 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
556 |
+
self.mlp = Mlp(
|
557 |
+
in_features=hidden_size,
|
558 |
+
hidden_features=mlp_hidden_dim,
|
559 |
+
act_layer=approx_gelu,
|
560 |
+
drop=0,
|
561 |
+
)
|
562 |
+
|
563 |
+
def forward(self, x, mask=None):
|
564 |
+
attn_bias = mask
|
565 |
+
if mask is not None:
|
566 |
+
mask = (
|
567 |
+
(mask[:, None] * mask[:, :, None])
|
568 |
+
.unsqueeze(1)
|
569 |
+
.expand(-1, self.attn.num_heads, -1, -1)
|
570 |
+
)
|
571 |
+
max_neg_value = -torch.finfo(x.dtype).max
|
572 |
+
attn_bias = (~mask) * max_neg_value
|
573 |
+
|
574 |
+
x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
|
575 |
+
x = x + self.mlp(self.norm2(x))
|
576 |
+
return x
|
577 |
+
|
578 |
+
|
579 |
+
class ResidualBlock(nn.Module):
|
580 |
+
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
581 |
+
super(ResidualBlock, self).__init__()
|
582 |
+
|
583 |
+
self.conv1 = nn.Conv2d(
|
584 |
+
in_planes,
|
585 |
+
planes,
|
586 |
+
kernel_size=3,
|
587 |
+
padding=1,
|
588 |
+
stride=stride,
|
589 |
+
padding_mode="zeros",
|
590 |
+
)
|
591 |
+
self.conv2 = nn.Conv2d(
|
592 |
+
planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
|
593 |
+
)
|
594 |
+
self.relu = nn.ReLU(inplace=True)
|
595 |
+
|
596 |
+
num_groups = planes // 8
|
597 |
+
|
598 |
+
if norm_fn == "group":
|
599 |
+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
600 |
+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
601 |
+
if not stride == 1:
|
602 |
+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
603 |
+
|
604 |
+
elif norm_fn == "batch":
|
605 |
+
self.norm1 = nn.BatchNorm2d(planes)
|
606 |
+
self.norm2 = nn.BatchNorm2d(planes)
|
607 |
+
if not stride == 1:
|
608 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
609 |
+
|
610 |
+
elif norm_fn == "instance":
|
611 |
+
self.norm1 = nn.InstanceNorm2d(planes)
|
612 |
+
self.norm2 = nn.InstanceNorm2d(planes)
|
613 |
+
if not stride == 1:
|
614 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
615 |
+
|
616 |
+
elif norm_fn == "none":
|
617 |
+
self.norm1 = nn.Sequential()
|
618 |
+
self.norm2 = nn.Sequential()
|
619 |
+
if not stride == 1:
|
620 |
+
self.norm3 = nn.Sequential()
|
621 |
+
|
622 |
+
if stride == 1:
|
623 |
+
self.downsample = None
|
624 |
+
|
625 |
+
else:
|
626 |
+
self.downsample = nn.Sequential(
|
627 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
|
628 |
+
)
|
629 |
+
|
630 |
+
def forward(self, x):
|
631 |
+
y = x
|
632 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
633 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
634 |
+
|
635 |
+
if self.downsample is not None:
|
636 |
+
x = self.downsample(x)
|
637 |
+
|
638 |
+
return self.relu(x + y)
|
639 |
+
|
640 |
+
|
641 |
+
class BasicEncoder(nn.Module):
|
642 |
+
def __init__(self, input_dim=3, output_dim=128, stride=4):
|
643 |
+
super(BasicEncoder, self).__init__()
|
644 |
+
self.stride = stride
|
645 |
+
self.norm_fn = "instance"
|
646 |
+
self.in_planes = output_dim // 2
|
647 |
+
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
648 |
+
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
649 |
+
|
650 |
+
self.conv1 = nn.Conv2d(
|
651 |
+
input_dim,
|
652 |
+
self.in_planes,
|
653 |
+
kernel_size=7,
|
654 |
+
stride=2,
|
655 |
+
padding=3,
|
656 |
+
padding_mode="zeros",
|
657 |
+
)
|
658 |
+
self.relu1 = nn.ReLU(inplace=True)
|
659 |
+
self.layer1 = self._make_layer(output_dim // 2, stride=1)
|
660 |
+
self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
|
661 |
+
self.layer3 = self._make_layer(output_dim, stride=2)
|
662 |
+
self.layer4 = self._make_layer(output_dim, stride=2)
|
663 |
+
|
664 |
+
self.conv2 = nn.Conv2d(
|
665 |
+
output_dim * 3 + output_dim // 4,
|
666 |
+
output_dim * 2,
|
667 |
+
kernel_size=3,
|
668 |
+
padding=1,
|
669 |
+
padding_mode="zeros",
|
670 |
+
)
|
671 |
+
self.relu2 = nn.ReLU(inplace=True)
|
672 |
+
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
|
673 |
+
for m in self.modules():
|
674 |
+
if isinstance(m, nn.Conv2d):
|
675 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
676 |
+
elif isinstance(m, (nn.InstanceNorm2d)):
|
677 |
+
if m.weight is not None:
|
678 |
+
nn.init.constant_(m.weight, 1)
|
679 |
+
if m.bias is not None:
|
680 |
+
nn.init.constant_(m.bias, 0)
|
681 |
+
|
682 |
+
def _make_layer(self, dim, stride=1):
|
683 |
+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
684 |
+
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
685 |
+
layers = (layer1, layer2)
|
686 |
+
|
687 |
+
self.in_planes = dim
|
688 |
+
return nn.Sequential(*layers)
|
689 |
+
|
690 |
+
def forward(self, x):
|
691 |
+
_, _, H, W = x.shape
|
692 |
+
|
693 |
+
x = self.conv1(x)
|
694 |
+
x = self.norm1(x)
|
695 |
+
x = self.relu1(x)
|
696 |
+
|
697 |
+
a = self.layer1(x)
|
698 |
+
b = self.layer2(a)
|
699 |
+
c = self.layer3(b)
|
700 |
+
d = self.layer4(c)
|
701 |
+
|
702 |
+
def _bilinear_intepolate(x):
|
703 |
+
return F.interpolate(
|
704 |
+
x,
|
705 |
+
(H // self.stride, W // self.stride),
|
706 |
+
mode="bilinear",
|
707 |
+
align_corners=True,
|
708 |
+
)
|
709 |
+
|
710 |
+
a = _bilinear_intepolate(a)
|
711 |
+
b = _bilinear_intepolate(b)
|
712 |
+
c = _bilinear_intepolate(c)
|
713 |
+
d = _bilinear_intepolate(d)
|
714 |
+
|
715 |
+
x = self.conv2(torch.cat([a, b, c, d], dim=1))
|
716 |
+
x = self.norm2(x)
|
717 |
+
x = self.relu2(x)
|
718 |
+
x = self.conv3(x)
|
719 |
+
return x
|
720 |
+
|
721 |
+
class EfficientUpdateFormer(nn.Module):
|
722 |
+
"""
|
723 |
+
Transformer model that updates track estimates.
|
724 |
+
"""
|
725 |
+
|
726 |
+
def __init__(
|
727 |
+
self,
|
728 |
+
space_depth=6,
|
729 |
+
time_depth=6,
|
730 |
+
input_dim=320,
|
731 |
+
hidden_size=384,
|
732 |
+
num_heads=8,
|
733 |
+
output_dim=130,
|
734 |
+
mlp_ratio=4.0,
|
735 |
+
num_virtual_tracks=64,
|
736 |
+
add_space_attn=True,
|
737 |
+
linear_layer_for_vis_conf=False,
|
738 |
+
use_time_conv=False,
|
739 |
+
use_time_mixer=False,
|
740 |
+
):
|
741 |
+
super().__init__()
|
742 |
+
self.out_channels = 2
|
743 |
+
self.num_heads = num_heads
|
744 |
+
self.hidden_size = hidden_size
|
745 |
+
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
746 |
+
if linear_layer_for_vis_conf:
|
747 |
+
self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
|
748 |
+
self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
|
749 |
+
else:
|
750 |
+
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
|
751 |
+
self.num_virtual_tracks = num_virtual_tracks
|
752 |
+
self.virual_tracks = nn.Parameter(
|
753 |
+
torch.randn(1, num_virtual_tracks, 1, hidden_size)
|
754 |
+
)
|
755 |
+
self.add_space_attn = add_space_attn
|
756 |
+
self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
|
757 |
+
|
758 |
+
if use_time_conv:
|
759 |
+
self.time_blocks = nn.ModuleList(
|
760 |
+
[
|
761 |
+
CNBlock1d(hidden_size, hidden_size, dense=False)
|
762 |
+
for _ in range(time_depth)
|
763 |
+
]
|
764 |
+
)
|
765 |
+
elif use_time_mixer:
|
766 |
+
self.time_blocks = nn.ModuleList(
|
767 |
+
[
|
768 |
+
MLPMixerBlock(
|
769 |
+
S=16,
|
770 |
+
dim=hidden_size,
|
771 |
+
depth=1,
|
772 |
+
)
|
773 |
+
for _ in range(time_depth)
|
774 |
+
]
|
775 |
+
)
|
776 |
+
else:
|
777 |
+
self.time_blocks = nn.ModuleList(
|
778 |
+
[
|
779 |
+
AttnBlock(
|
780 |
+
hidden_size,
|
781 |
+
num_heads,
|
782 |
+
mlp_ratio=mlp_ratio,
|
783 |
+
attn_class=Attention,
|
784 |
+
)
|
785 |
+
for _ in range(time_depth)
|
786 |
+
]
|
787 |
+
)
|
788 |
+
|
789 |
+
if add_space_attn:
|
790 |
+
self.space_virtual_blocks = nn.ModuleList(
|
791 |
+
[
|
792 |
+
AttnBlock(
|
793 |
+
hidden_size,
|
794 |
+
num_heads,
|
795 |
+
mlp_ratio=mlp_ratio,
|
796 |
+
attn_class=Attention,
|
797 |
+
)
|
798 |
+
for _ in range(space_depth)
|
799 |
+
]
|
800 |
+
)
|
801 |
+
self.space_point2virtual_blocks = nn.ModuleList(
|
802 |
+
[
|
803 |
+
CrossAttnBlock(
|
804 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
805 |
+
)
|
806 |
+
for _ in range(space_depth)
|
807 |
+
]
|
808 |
+
)
|
809 |
+
self.space_virtual2point_blocks = nn.ModuleList(
|
810 |
+
[
|
811 |
+
CrossAttnBlock(
|
812 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
813 |
+
)
|
814 |
+
for _ in range(space_depth)
|
815 |
+
]
|
816 |
+
)
|
817 |
+
assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
|
818 |
+
self.initialize_weights()
|
819 |
+
|
820 |
+
def initialize_weights(self):
|
821 |
+
def _basic_init(module):
|
822 |
+
if isinstance(module, nn.Linear):
|
823 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
824 |
+
if module.bias is not None:
|
825 |
+
nn.init.constant_(module.bias, 0)
|
826 |
+
torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
|
827 |
+
if self.linear_layer_for_vis_conf:
|
828 |
+
torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)
|
829 |
+
|
830 |
+
def _trunc_init(module):
|
831 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
832 |
+
if isinstance(module, nn.Linear):
|
833 |
+
torch.nn.init.trunc_normal_(module.weight, std=0.02)
|
834 |
+
if module.bias is not None:
|
835 |
+
nn.init.zeros_(module.bias)
|
836 |
+
|
837 |
+
self.apply(_basic_init)
|
838 |
+
|
839 |
+
def forward(self, input_tensor, mask=None, add_space_attn=True):
|
840 |
+
tokens = self.input_transform(input_tensor)
|
841 |
+
|
842 |
+
B, _, T, _ = tokens.shape
|
843 |
+
virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
|
844 |
+
tokens = torch.cat([tokens, virtual_tokens], dim=1)
|
845 |
+
|
846 |
+
_, N, _, _ = tokens.shape
|
847 |
+
j = 0
|
848 |
+
layers = []
|
849 |
+
for i in range(len(self.time_blocks)):
|
850 |
+
time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
|
851 |
+
time_tokens = self.time_blocks[i](time_tokens)
|
852 |
+
|
853 |
+
tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
|
854 |
+
if (
|
855 |
+
add_space_attn
|
856 |
+
and hasattr(self, "space_virtual_blocks")
|
857 |
+
and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
|
858 |
+
):
|
859 |
+
space_tokens = (
|
860 |
+
tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
|
861 |
+
) # B N T C -> (B T) N C
|
862 |
+
|
863 |
+
point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
|
864 |
+
virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
|
865 |
+
|
866 |
+
virtual_tokens = self.space_virtual2point_blocks[j](
|
867 |
+
virtual_tokens, point_tokens, mask=mask
|
868 |
+
)
|
869 |
+
|
870 |
+
virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
|
871 |
+
point_tokens = self.space_point2virtual_blocks[j](
|
872 |
+
point_tokens, virtual_tokens, mask=mask
|
873 |
+
)
|
874 |
+
|
875 |
+
space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
|
876 |
+
tokens = space_tokens.view(B, T, N, -1).permute(
|
877 |
+
0, 2, 1, 3
|
878 |
+
) # (B T) N C -> B N T C
|
879 |
+
j += 1
|
880 |
+
tokens = tokens[:, : N - self.num_virtual_tracks]
|
881 |
+
|
882 |
+
flow = self.flow_head(tokens)
|
883 |
+
if self.linear_layer_for_vis_conf:
|
884 |
+
vis_conf = self.vis_conf_head(tokens)
|
885 |
+
flow = torch.cat([flow, vis_conf], dim=-1)
|
886 |
+
|
887 |
+
return flow
|
888 |
+
|
889 |
+
|
890 |
+
class MMPreNormResidual(nn.Module):
|
891 |
+
def __init__(self, dim, fn):
|
892 |
+
super().__init__()
|
893 |
+
self.fn = fn
|
894 |
+
self.norm = nn.LayerNorm(dim)
|
895 |
+
|
896 |
+
def forward(self, x):
|
897 |
+
return self.fn(self.norm(x)) + x
|
898 |
+
|
899 |
+
def MMFeedForward(dim, expansion_factor=4, dropout=0., dense=nn.Linear):
|
900 |
+
return nn.Sequential(
|
901 |
+
dense(dim, dim * expansion_factor),
|
902 |
+
nn.GELU(),
|
903 |
+
nn.Dropout(dropout),
|
904 |
+
dense(dim * expansion_factor, dim),
|
905 |
+
nn.Dropout(dropout)
|
906 |
+
)
|
907 |
+
|
908 |
+
def MLPMixer(S, input_dim, dim, output_dim, depth=6, expansion_factor=4, dropout=0., do_reduce=False):
|
909 |
+
# input is coming in as B,S,C, as standard for mlp and transformer
|
910 |
+
# chan_first treats S as the channel dim, and transforms it to a new S
|
911 |
+
# chan_last treats C as the channel dim, and transforms it to a new C
|
912 |
+
chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
|
913 |
+
if do_reduce:
|
914 |
+
return nn.Sequential(
|
915 |
+
nn.Linear(input_dim, dim),
|
916 |
+
*[nn.Sequential(
|
917 |
+
MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
|
918 |
+
MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
|
919 |
+
) for _ in range(depth)],
|
920 |
+
nn.LayerNorm(dim),
|
921 |
+
Reduce('b n c -> b c', 'mean'),
|
922 |
+
nn.Linear(dim, output_dim)
|
923 |
+
)
|
924 |
+
else:
|
925 |
+
return nn.Sequential(
|
926 |
+
nn.Linear(input_dim, dim),
|
927 |
+
*[nn.Sequential(
|
928 |
+
MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
|
929 |
+
MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
|
930 |
+
) for _ in range(depth)],
|
931 |
+
)
|
932 |
+
|
933 |
+
def MLPMixerBlock(S, dim, depth=1, expansion_factor=4, dropout=0., do_reduce=False):
|
934 |
+
# input is coming in as B,S,C, as standard for mlp and transformer
|
935 |
+
# chan_first treats S as the channel dim, and transforms it to a new S
|
936 |
+
# chan_last treats C as the channel dim, and transforms it to a new C
|
937 |
+
chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
|
938 |
+
return nn.Sequential(
|
939 |
+
*[nn.Sequential(
|
940 |
+
MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
|
941 |
+
MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
|
942 |
+
) for _ in range(depth)],
|
943 |
+
)
|
944 |
+
|
945 |
+
|
946 |
+
class MlpUpdateFormer(nn.Module):
|
947 |
+
"""
|
948 |
+
Transformer model that updates track estimates.
|
949 |
+
"""
|
950 |
+
|
951 |
+
def __init__(
|
952 |
+
self,
|
953 |
+
space_depth=6,
|
954 |
+
time_depth=6,
|
955 |
+
input_dim=320,
|
956 |
+
hidden_size=384,
|
957 |
+
num_heads=8,
|
958 |
+
output_dim=130,
|
959 |
+
mlp_ratio=4.0,
|
960 |
+
num_virtual_tracks=64,
|
961 |
+
add_space_attn=True,
|
962 |
+
linear_layer_for_vis_conf=False,
|
963 |
+
):
|
964 |
+
super().__init__()
|
965 |
+
self.out_channels = 2
|
966 |
+
self.num_heads = num_heads
|
967 |
+
self.hidden_size = hidden_size
|
968 |
+
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
969 |
+
if linear_layer_for_vis_conf:
|
970 |
+
self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
|
971 |
+
self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
|
972 |
+
else:
|
973 |
+
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
|
974 |
+
self.num_virtual_tracks = num_virtual_tracks
|
975 |
+
self.virual_tracks = nn.Parameter(
|
976 |
+
torch.randn(1, num_virtual_tracks, 1, hidden_size)
|
977 |
+
)
|
978 |
+
self.add_space_attn = add_space_attn
|
979 |
+
self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
|
980 |
+
self.time_blocks = nn.ModuleList(
|
981 |
+
[
|
982 |
+
MLPMixer(
|
983 |
+
S=16,
|
984 |
+
input_dim=hidden_size,
|
985 |
+
dim=hidden_size,
|
986 |
+
output_dim=hidden_size,
|
987 |
+
depth=1,
|
988 |
+
)
|
989 |
+
for _ in range(time_depth)
|
990 |
+
]
|
991 |
+
)
|
992 |
+
|
993 |
+
if add_space_attn:
|
994 |
+
self.space_virtual_blocks = nn.ModuleList(
|
995 |
+
[
|
996 |
+
AttnBlock(
|
997 |
+
hidden_size,
|
998 |
+
num_heads,
|
999 |
+
mlp_ratio=mlp_ratio,
|
1000 |
+
attn_class=Attention,
|
1001 |
+
)
|
1002 |
+
for _ in range(space_depth)
|
1003 |
+
]
|
1004 |
+
)
|
1005 |
+
self.space_point2virtual_blocks = nn.ModuleList(
|
1006 |
+
[
|
1007 |
+
CrossAttnBlock(
|
1008 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
1009 |
+
)
|
1010 |
+
for _ in range(space_depth)
|
1011 |
+
]
|
1012 |
+
)
|
1013 |
+
self.space_virtual2point_blocks = nn.ModuleList(
|
1014 |
+
[
|
1015 |
+
CrossAttnBlock(
|
1016 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
1017 |
+
)
|
1018 |
+
for _ in range(space_depth)
|
1019 |
+
]
|
1020 |
+
)
|
1021 |
+
assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
|
1022 |
+
self.initialize_weights()
|
1023 |
+
|
1024 |
+
def initialize_weights(self):
|
1025 |
+
def _basic_init(module):
|
1026 |
+
if isinstance(module, nn.Linear):
|
1027 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
1028 |
+
if module.bias is not None:
|
1029 |
+
nn.init.constant_(module.bias, 0)
|
1030 |
+
torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
|
1031 |
+
if self.linear_layer_for_vis_conf:
|
1032 |
+
torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)
|
1033 |
+
|
1034 |
+
def _trunc_init(module):
|
1035 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
1036 |
+
if isinstance(module, nn.Linear):
|
1037 |
+
torch.nn.init.trunc_normal_(module.weight, std=0.02)
|
1038 |
+
if module.bias is not None:
|
1039 |
+
nn.init.zeros_(module.bias)
|
1040 |
+
|
1041 |
+
self.apply(_basic_init)
|
1042 |
+
|
1043 |
+
def forward(self, input_tensor, mask=None, add_space_attn=True):
|
1044 |
+
tokens = self.input_transform(input_tensor)
|
1045 |
+
|
1046 |
+
B, _, T, _ = tokens.shape
|
1047 |
+
virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
|
1048 |
+
tokens = torch.cat([tokens, virtual_tokens], dim=1)
|
1049 |
+
|
1050 |
+
_, N, _, _ = tokens.shape
|
1051 |
+
j = 0
|
1052 |
+
layers = []
|
1053 |
+
for i in range(len(self.time_blocks)):
|
1054 |
+
time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
|
1055 |
+
time_tokens = self.time_blocks[i](time_tokens)
|
1056 |
+
|
1057 |
+
tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
|
1058 |
+
if (
|
1059 |
+
add_space_attn
|
1060 |
+
and hasattr(self, "space_virtual_blocks")
|
1061 |
+
and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
|
1062 |
+
):
|
1063 |
+
space_tokens = (
|
1064 |
+
tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
|
1065 |
+
) # B N T C -> (B T) N C
|
1066 |
+
|
1067 |
+
point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
|
1068 |
+
virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
|
1069 |
+
|
1070 |
+
virtual_tokens = self.space_virtual2point_blocks[j](
|
1071 |
+
virtual_tokens, point_tokens, mask=mask
|
1072 |
+
)
|
1073 |
+
|
1074 |
+
virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
|
1075 |
+
point_tokens = self.space_point2virtual_blocks[j](
|
1076 |
+
point_tokens, virtual_tokens, mask=mask
|
1077 |
+
)
|
1078 |
+
|
1079 |
+
space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
|
1080 |
+
tokens = space_tokens.view(B, T, N, -1).permute(
|
1081 |
+
0, 2, 1, 3
|
1082 |
+
) # (B T) N C -> B N T C
|
1083 |
+
j += 1
|
1084 |
+
tokens = tokens[:, : N - self.num_virtual_tracks]
|
1085 |
+
|
1086 |
+
flow = self.flow_head(tokens)
|
1087 |
+
if self.linear_layer_for_vis_conf:
|
1088 |
+
vis_conf = self.vis_conf_head(tokens)
|
1089 |
+
flow = torch.cat([flow, vis_conf], dim=-1)
|
1090 |
+
|
1091 |
+
return flow
|
1092 |
+
|
1093 |
+
class BasicMotionEncoder(nn.Module):
|
1094 |
+
def __init__(self, corr_channel, dim=128, pdim=2):
|
1095 |
+
super(BasicMotionEncoder, self).__init__()
|
1096 |
+
self.pdim = pdim
|
1097 |
+
self.convc1 = nn.Conv2d(corr_channel, dim*4, 1, padding=0)
|
1098 |
+
self.convc2 = nn.Conv2d(dim*4, dim+dim//2, 3, padding=1)
|
1099 |
+
if pdim==2 or pdim==4:
|
1100 |
+
self.convf1 = nn.Conv2d(pdim, dim*2, 5, padding=2)
|
1101 |
+
self.convf2 = nn.Conv2d(dim*2, dim//2, 3, padding=1)
|
1102 |
+
self.conv = nn.Conv2d(dim*2, dim-pdim, 3, padding=1)
|
1103 |
+
else:
|
1104 |
+
self.conv = nn.Conv2d(dim+dim//2+pdim, dim, 3, padding=1)
|
1105 |
+
|
1106 |
+
def forward(self, flow, corr):
|
1107 |
+
cor = F.relu(self.convc1(corr))
|
1108 |
+
cor = F.relu(self.convc2(cor))
|
1109 |
+
if self.pdim==2 or self.pdim==4:
|
1110 |
+
flo = F.relu(self.convf1(flow))
|
1111 |
+
flo = F.relu(self.convf2(flo))
|
1112 |
+
cor_flo = torch.cat([cor, flo], dim=1)
|
1113 |
+
out = F.relu(self.conv(cor_flo))
|
1114 |
+
return torch.cat([out, flow], dim=1)
|
1115 |
+
else:
|
1116 |
+
# the flow is already encoded to something nice
|
1117 |
+
cor_flo = torch.cat([cor, flow], dim=1)
|
1118 |
+
return F.relu(self.conv(cor_flo))
|
1119 |
+
# return torch.cat([out, flow], dim=1)
|
1120 |
+
|
1121 |
+
def conv133_encoder(input_dim, dim, expansion_factor=4):
|
1122 |
+
return nn.Sequential(
|
1123 |
+
nn.Conv2d(input_dim, dim*expansion_factor, kernel_size=1),
|
1124 |
+
nn.GELU(),
|
1125 |
+
nn.Conv2d(dim*expansion_factor, dim*expansion_factor, kernel_size=3, padding=1),
|
1126 |
+
nn.GELU(),
|
1127 |
+
nn.Conv2d(dim*expansion_factor, dim, kernel_size=3, padding=1),
|
1128 |
+
)
|
1129 |
+
|
1130 |
+
class BasicUpdateBlock(nn.Module):
|
1131 |
+
def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128):
|
1132 |
+
# flowfeat is hdim; ctxfeat is dim. typically hdim==cdim.
|
1133 |
+
super(BasicUpdateBlock, self).__init__()
|
1134 |
+
self.encoder = BasicMotionEncoder(corr_channel, dim=cdim)
|
1135 |
+
self.compressor = conv1x1(2*cdim+hdim, hdim)
|
1136 |
+
|
1137 |
+
self.refine = []
|
1138 |
+
for i in range(num_blocks):
|
1139 |
+
self.refine.append(CNBlock1d(hdim, hdim))
|
1140 |
+
self.refine.append(CNBlock2d(hdim, hdim))
|
1141 |
+
self.refine = nn.ModuleList(self.refine)
|
1142 |
+
|
1143 |
+
def forward(self, flowfeat, ctxfeat, corr, flow, S, upsample=True):
|
1144 |
+
BS,C,H,W = flowfeat.shape
|
1145 |
+
B = BS//S
|
1146 |
+
|
1147 |
+
# with torch.no_grad():
|
1148 |
+
motion_features = self.encoder(flow, corr)
|
1149 |
+
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features], dim=1))
|
1150 |
+
|
1151 |
+
for blk in self.refine:
|
1152 |
+
flowfeat = blk(flowfeat, S)
|
1153 |
+
return flowfeat
|
1154 |
+
|
1155 |
+
class FullUpdateBlock(nn.Module):
|
1156 |
+
def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128, pdim=2, use_attn=False):
|
1157 |
+
# flowfeat is hdim; ctxfeat is dim. typically hdim==cdim.
|
1158 |
+
super(FullUpdateBlock, self).__init__()
|
1159 |
+
self.encoder = BasicMotionEncoder(corr_channel, dim=cdim, pdim=pdim)
|
1160 |
+
|
1161 |
+
# note we have hdim==cdim
|
1162 |
+
# compressor chans:
|
1163 |
+
# dim for flowfeat
|
1164 |
+
# dim for ctxfeat
|
1165 |
+
# dim for motion_features
|
1166 |
+
# pdim for flow (if p 2, like if we give sincos(relflow))
|
1167 |
+
# 2 for visconf
|
1168 |
+
|
1169 |
+
if pdim==2:
|
1170 |
+
# hdim==cdim
|
1171 |
+
# dim for flowfeat
|
1172 |
+
# dim for ctxfeat
|
1173 |
+
# dim for motion_features
|
1174 |
+
# 2 for visconf
|
1175 |
+
self.compressor = conv1x1(2*cdim+hdim+2, hdim)
|
1176 |
+
else:
|
1177 |
+
# we concatenate the flow info again, to not lose it (e.g., from the relu)
|
1178 |
+
self.compressor = conv1x1(2*cdim+hdim+2+pdim, hdim)
|
1179 |
+
|
1180 |
+
self.refine = []
|
1181 |
+
for i in range(num_blocks):
|
1182 |
+
self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn))
|
1183 |
+
self.refine.append(CNBlock2d(hdim, hdim))
|
1184 |
+
self.refine = nn.ModuleList(self.refine)
|
1185 |
+
|
1186 |
+
def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
|
1187 |
+
BS,C,H,W = flowfeat.shape
|
1188 |
+
B = BS//S
|
1189 |
+
motion_features = self.encoder(flow, corr)
|
1190 |
+
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features, visconf], dim=1))
|
1191 |
+
for blk in self.refine:
|
1192 |
+
flowfeat = blk(flowfeat, S)
|
1193 |
+
return flowfeat
|
1194 |
+
|
1195 |
+
class MixerUpdateBlock(nn.Module):
|
1196 |
+
def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128):
|
1197 |
+
# flowfeat is hdim; ctxfeat is dim. typically hdim==cdim.
|
1198 |
+
super(MixerUpdateBlock, self).__init__()
|
1199 |
+
self.encoder = BasicMotionEncoder(corr_channel, dim=cdim)
|
1200 |
+
self.compressor = conv1x1(2*cdim+hdim, hdim)
|
1201 |
+
|
1202 |
+
self.refine = []
|
1203 |
+
for i in range(num_blocks):
|
1204 |
+
self.refine.append(CNBlock1d(hdim, hdim, use_mixer=True))
|
1205 |
+
self.refine.append(CNBlock2d(hdim, hdim))
|
1206 |
+
self.refine = nn.ModuleList(self.refine)
|
1207 |
+
|
1208 |
+
def forward(self, flowfeat, ctxfeat, corr, flow, S, upsample=True):
|
1209 |
+
BS,C,H,W = flowfeat.shape
|
1210 |
+
B = BS//S
|
1211 |
+
|
1212 |
+
# with torch.no_grad():
|
1213 |
+
motion_features = self.encoder(flow, corr)
|
1214 |
+
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features], dim=1))
|
1215 |
+
|
1216 |
+
for ii, blk in enumerate(self.refine):
|
1217 |
+
flowfeat = blk(flowfeat, S)
|
1218 |
+
return flowfeat
|
1219 |
+
|
1220 |
+
class FacUpdateBlock(nn.Module):
|
1221 |
+
def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128, pdim=84, use_attn=False):
|
1222 |
+
super(FacUpdateBlock, self).__init__()
|
1223 |
+
self.corr_encoder = conv133_encoder(corr_channel, cdim)
|
1224 |
+
# note we have hdim==cdim
|
1225 |
+
# compressor chans:
|
1226 |
+
# dim for flowfeat
|
1227 |
+
# dim for ctxfeat
|
1228 |
+
# dim for corr
|
1229 |
+
# pdim for flow
|
1230 |
+
# 2 for visconf
|
1231 |
+
self.compressor = conv1x1(2*cdim+hdim+2+pdim, hdim)
|
1232 |
+
self.refine = []
|
1233 |
+
for i in range(num_blocks):
|
1234 |
+
self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn))
|
1235 |
+
self.refine.append(CNBlock2d(hdim, hdim))
|
1236 |
+
self.refine = nn.ModuleList(self.refine)
|
1237 |
+
|
1238 |
+
def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
|
1239 |
+
BS,C,H,W = flowfeat.shape
|
1240 |
+
B = BS//S
|
1241 |
+
corr = self.corr_encoder(corr)
|
1242 |
+
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, corr, visconf, flow], dim=1))
|
1243 |
+
for blk in self.refine:
|
1244 |
+
flowfeat = blk(flowfeat, S)
|
1245 |
+
return flowfeat
|
1246 |
+
|
1247 |
+
class CleanUpdateBlock(nn.Module):
|
1248 |
+
def __init__(self, corr_channel, num_blocks, cdim=128, hdim=256, pdim=84, use_attn=False, use_layer_scale=True):
|
1249 |
+
super(CleanUpdateBlock, self).__init__()
|
1250 |
+
self.corr_encoder = conv133_encoder(corr_channel, cdim)
|
1251 |
+
# compressor chans:
|
1252 |
+
# cdim for flowfeat
|
1253 |
+
# cdim for ctxfeat
|
1254 |
+
# cdim for corrfeat
|
1255 |
+
# pdim for flow
|
1256 |
+
# 2 for visconf
|
1257 |
+
self.compressor = conv1x1(3*cdim+pdim+2, hdim)
|
1258 |
+
self.refine = []
|
1259 |
+
for i in range(num_blocks):
|
1260 |
+
self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn, use_layer_scale=use_layer_scale))
|
1261 |
+
self.refine.append(CNBlock2d(hdim, hdim, use_layer_scale=use_layer_scale))
|
1262 |
+
self.refine = nn.ModuleList(self.refine)
|
1263 |
+
self.final_conv = conv1x1(hdim, cdim)
|
1264 |
+
|
1265 |
+
def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
|
1266 |
+
BS,C,H,W = flowfeat.shape
|
1267 |
+
B = BS//S
|
1268 |
+
corrfeat = self.corr_encoder(corr)
|
1269 |
+
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, corrfeat, flow, visconf], dim=1))
|
1270 |
+
for blk in self.refine:
|
1271 |
+
flowfeat = blk(flowfeat, S)
|
1272 |
+
flowfeat = self.final_conv(flowfeat)
|
1273 |
+
return flowfeat
|
1274 |
+
|
1275 |
+
class RelUpdateBlock(nn.Module):
|
1276 |
+
def __init__(self, corr_channel, num_blocks, cdim=128, hdim=128, pdim=4, use_attn=True, use_mixer=False, use_conv=False, use_convb=False, use_layer_scale=True, no_time=False, no_space=False, no_ctx=False):
|
1277 |
+
super(RelUpdateBlock, self).__init__()
|
1278 |
+
self.motion_encoder = BasicMotionEncoder(corr_channel, dim=hdim, pdim=pdim) # B,hdim,H,W
|
1279 |
+
self.no_ctx = no_ctx
|
1280 |
+
if no_ctx:
|
1281 |
+
self.compressor = conv1x1(cdim+hdim+2, hdim)
|
1282 |
+
else:
|
1283 |
+
self.compressor = conv1x1(2*cdim+hdim+2, hdim)
|
1284 |
+
self.refine = []
|
1285 |
+
for i in range(num_blocks):
|
1286 |
+
if not no_time:
|
1287 |
+
self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn, use_mixer=use_mixer, use_conv=use_conv, use_convb=use_convb, use_layer_scale=use_layer_scale))
|
1288 |
+
if not no_space:
|
1289 |
+
self.refine.append(CNBlock2d(hdim, hdim, use_layer_scale=use_layer_scale))
|
1290 |
+
self.refine = nn.ModuleList(self.refine)
|
1291 |
+
self.final_conv = conv1x1(hdim, cdim)
|
1292 |
+
|
1293 |
+
def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
|
1294 |
+
BS,C,H,W = flowfeat.shape
|
1295 |
+
B = BS//S
|
1296 |
+
motion_features = self.motion_encoder(flow, corr)
|
1297 |
+
if self.no_ctx:
|
1298 |
+
flowfeat = self.compressor(torch.cat([flowfeat, motion_features, visconf], dim=1))
|
1299 |
+
else:
|
1300 |
+
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features, visconf], dim=1))
|
1301 |
+
for blk in self.refine:
|
1302 |
+
flowfeat = blk(flowfeat, S)
|
1303 |
+
flowfeat = self.final_conv(flowfeat)
|
1304 |
+
return flowfeat
|
utils/basic.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
EPS = 1e-6
|
5 |
+
|
6 |
+
def sub2ind(height, width, y, x):
|
7 |
+
return y*width + x
|
8 |
+
|
9 |
+
def ind2sub(height, width, ind):
|
10 |
+
y = ind // width
|
11 |
+
x = ind % width
|
12 |
+
return y, x
|
13 |
+
|
14 |
+
def get_lr_str(lr):
|
15 |
+
lrn = "%.1e" % lr # e.g., 5.0e-04
|
16 |
+
lrn = lrn[0] + lrn[3:5] + lrn[-1] # e.g., 5e-4
|
17 |
+
return lrn
|
18 |
+
|
19 |
+
def strnum(x):
|
20 |
+
s = '%g' % x
|
21 |
+
if '.' in s:
|
22 |
+
if x < 1.0:
|
23 |
+
s = s[s.index('.'):]
|
24 |
+
s = s[:min(len(s),4)]
|
25 |
+
return s
|
26 |
+
|
27 |
+
def assert_same_shape(t1, t2):
|
28 |
+
for (x, y) in zip(list(t1.shape), list(t2.shape)):
|
29 |
+
assert(x==y)
|
30 |
+
|
31 |
+
def mkdir(path):
|
32 |
+
if not os.path.exists(path):
|
33 |
+
os.makedirs(path)
|
34 |
+
|
35 |
+
def print_stats(name, tensor):
|
36 |
+
shape = tensor.shape
|
37 |
+
tensor = tensor.detach().cpu().numpy()
|
38 |
+
print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape)
|
39 |
+
|
40 |
+
def normalize_single(d):
|
41 |
+
# d is a whatever shape torch tensor
|
42 |
+
dmin = torch.min(d)
|
43 |
+
dmax = torch.max(d)
|
44 |
+
d = (d-dmin)/(EPS+(dmax-dmin))
|
45 |
+
return d
|
46 |
+
|
47 |
+
def normalize(d):
|
48 |
+
# d is B x whatever. normalize within each element of the batch
|
49 |
+
out = torch.zeros(d.size(), dtype=d.dtype, device=d.device)
|
50 |
+
B = list(d.size())[0]
|
51 |
+
for b in list(range(B)):
|
52 |
+
out[b] = normalize_single(d[b])
|
53 |
+
return out
|
54 |
+
|
55 |
+
def meshgrid2d(B, Y, X, stack=False, norm=False, device='cuda', on_chans=False):
|
56 |
+
# returns a meshgrid sized B x Y x X
|
57 |
+
|
58 |
+
grid_y = torch.linspace(0.0, Y-1, Y, device=torch.device(device))
|
59 |
+
grid_y = torch.reshape(grid_y, [1, Y, 1])
|
60 |
+
grid_y = grid_y.repeat(B, 1, X)
|
61 |
+
|
62 |
+
grid_x = torch.linspace(0.0, X-1, X, device=torch.device(device))
|
63 |
+
grid_x = torch.reshape(grid_x, [1, 1, X])
|
64 |
+
grid_x = grid_x.repeat(B, Y, 1)
|
65 |
+
|
66 |
+
if norm:
|
67 |
+
grid_y, grid_x = normalize_grid2d(
|
68 |
+
grid_y, grid_x, Y, X)
|
69 |
+
|
70 |
+
if stack:
|
71 |
+
# note we stack in xy order
|
72 |
+
# (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
|
73 |
+
if on_chans:
|
74 |
+
grid = torch.stack([grid_x, grid_y], dim=1)
|
75 |
+
else:
|
76 |
+
grid = torch.stack([grid_x, grid_y], dim=-1)
|
77 |
+
return grid
|
78 |
+
else:
|
79 |
+
return grid_y, grid_x
|
80 |
+
|
81 |
+
def gridcloud2d(B, Y, X, norm=False, device='cuda'):
|
82 |
+
# we want to sample for each location in the grid
|
83 |
+
grid_y, grid_x = meshgrid2d(B, Y, X, norm=norm, device=device)
|
84 |
+
x = torch.reshape(grid_x, [B, -1])
|
85 |
+
y = torch.reshape(grid_y, [B, -1])
|
86 |
+
# these are B x N
|
87 |
+
xy = torch.stack([x, y], dim=2)
|
88 |
+
# this is B x N x 2
|
89 |
+
return xy
|
90 |
+
|
91 |
+
def reduce_masked_mean(x, mask, dim=None, keepdim=False, broadcast=False):
|
92 |
+
# x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting
|
93 |
+
# returns shape-1
|
94 |
+
# axis can be a list of axes
|
95 |
+
if not broadcast:
|
96 |
+
for (a,b) in zip(x.size(), mask.size()):
|
97 |
+
if not a==b:
|
98 |
+
print('some shape mismatch:', x.shape, mask.shape)
|
99 |
+
assert(a==b) # some shape mismatch!
|
100 |
+
# assert(x.size() == mask.size())
|
101 |
+
prod = x*mask
|
102 |
+
if dim is None:
|
103 |
+
numer = torch.sum(prod)
|
104 |
+
denom = EPS+torch.sum(mask)
|
105 |
+
else:
|
106 |
+
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
|
107 |
+
denom = EPS+torch.sum(mask, dim=dim, keepdim=keepdim)
|
108 |
+
mean = numer/denom
|
109 |
+
return mean
|
110 |
+
|
111 |
+
def reduce_masked_median(x, mask, keep_batch=False):
|
112 |
+
# x and mask are the same shape
|
113 |
+
assert(x.size() == mask.size())
|
114 |
+
device = x.device
|
115 |
+
|
116 |
+
B = list(x.shape)[0]
|
117 |
+
x = x.detach().cpu().numpy()
|
118 |
+
mask = mask.detach().cpu().numpy()
|
119 |
+
|
120 |
+
if keep_batch:
|
121 |
+
x = np.reshape(x, [B, -1])
|
122 |
+
mask = np.reshape(mask, [B, -1])
|
123 |
+
meds = np.zeros([B], np.float32)
|
124 |
+
for b in list(range(B)):
|
125 |
+
xb = x[b]
|
126 |
+
mb = mask[b]
|
127 |
+
if np.sum(mb) > 0:
|
128 |
+
xb = xb[mb > 0]
|
129 |
+
meds[b] = np.median(xb)
|
130 |
+
else:
|
131 |
+
meds[b] = np.nan
|
132 |
+
meds = torch.from_numpy(meds).to(device)
|
133 |
+
return meds.float()
|
134 |
+
else:
|
135 |
+
x = np.reshape(x, [-1])
|
136 |
+
mask = np.reshape(mask, [-1])
|
137 |
+
if np.sum(mask) > 0:
|
138 |
+
x = x[mask > 0]
|
139 |
+
med = np.median(x)
|
140 |
+
else:
|
141 |
+
med = np.nan
|
142 |
+
med = np.array([med], np.float32)
|
143 |
+
med = torch.from_numpy(med).to(device)
|
144 |
+
return med.float()
|
utils/data.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import dataclasses
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from typing import Any, Optional, Dict
|
6 |
+
|
7 |
+
|
8 |
+
@dataclass(eq=False)
|
9 |
+
class VideoData:
|
10 |
+
"""
|
11 |
+
Dataclass for storing video tracks data.
|
12 |
+
"""
|
13 |
+
|
14 |
+
video: torch.Tensor # B,S,C,H,W
|
15 |
+
trajs: torch.Tensor # B,S,N,2
|
16 |
+
visibs: torch.Tensor # B,S,N
|
17 |
+
valids: Optional[torch.Tensor] = None # B,S,N
|
18 |
+
seq_name: Optional[str] = None
|
19 |
+
dname: Optional[str] = None
|
20 |
+
aug_video: Optional[torch.Tensor] = None
|
21 |
+
|
22 |
+
|
23 |
+
def collate_fn(batch):
|
24 |
+
"""
|
25 |
+
Collate function for video tracks data.
|
26 |
+
"""
|
27 |
+
video = torch.stack([b.video for b in batch], dim=0)
|
28 |
+
trajs = torch.stack([b.trajs for b in batch], dim=0)
|
29 |
+
visibs = torch.stack([b.visibs for b in batch], dim=0)
|
30 |
+
seq_name = [b.seq_name for b in batch]
|
31 |
+
dname = [b.dname for b in batch]
|
32 |
+
|
33 |
+
return VideoData(
|
34 |
+
video=video,
|
35 |
+
trajs=trajs,
|
36 |
+
visibs=visibs,
|
37 |
+
seq_name=seq_name,
|
38 |
+
dname=dname,
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
def collate_fn_train(batch):
|
43 |
+
"""
|
44 |
+
Collate function for video tracks data during training.
|
45 |
+
"""
|
46 |
+
gotit = [gotit for _, gotit in batch]
|
47 |
+
video = torch.stack([b.video for b, _ in batch], dim=0)
|
48 |
+
trajs = torch.stack([b.trajs for b, _ in batch], dim=0)
|
49 |
+
visibs = torch.stack([b.visibs for b, _ in batch], dim=0)
|
50 |
+
valids = torch.stack([b.valids for b, _ in batch], dim=0)
|
51 |
+
seq_name = [b.seq_name for b, _ in batch]
|
52 |
+
dname = [b.dname for b, _ in batch]
|
53 |
+
|
54 |
+
return (
|
55 |
+
VideoData(
|
56 |
+
video=video,
|
57 |
+
trajs=trajs,
|
58 |
+
visibs=visibs,
|
59 |
+
valids=valids,
|
60 |
+
seq_name=seq_name,
|
61 |
+
dname=dname,
|
62 |
+
),
|
63 |
+
gotit,
|
64 |
+
)
|
65 |
+
|
66 |
+
|
67 |
+
def try_to_cuda(t: Any) -> Any:
|
68 |
+
"""
|
69 |
+
Try to move the input variable `t` to a cuda device.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
t: Input.
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
t_cuda: `t` moved to a cuda device, if supported.
|
76 |
+
"""
|
77 |
+
try:
|
78 |
+
t = t.float().cuda()
|
79 |
+
except AttributeError:
|
80 |
+
pass
|
81 |
+
return t
|
82 |
+
|
83 |
+
|
84 |
+
def dataclass_to_cuda_(obj):
|
85 |
+
"""
|
86 |
+
Move all contents of a dataclass to cuda inplace if supported.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
batch: Input dataclass.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
batch_cuda: `batch` moved to a cuda device, if supported.
|
93 |
+
"""
|
94 |
+
for f in dataclasses.fields(obj):
|
95 |
+
setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
|
96 |
+
return obj
|
utils/improc.py
ADDED
@@ -0,0 +1,1103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import utils.basic
|
4 |
+
import utils.py
|
5 |
+
from sklearn.decomposition import PCA
|
6 |
+
from matplotlib import cm
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import cv2
|
9 |
+
import torch.nn.functional as F
|
10 |
+
EPS = 1e-6
|
11 |
+
|
12 |
+
from skimage.color import (
|
13 |
+
rgb2lab, rgb2yuv, rgb2ycbcr, lab2rgb, yuv2rgb, ycbcr2rgb,
|
14 |
+
rgb2hsv, hsv2rgb, rgb2xyz, xyz2rgb, rgb2hed, hed2rgb)
|
15 |
+
|
16 |
+
def _convert(input_, type_):
|
17 |
+
return {
|
18 |
+
'float': input_.float(),
|
19 |
+
'double': input_.double(),
|
20 |
+
}.get(type_, input_)
|
21 |
+
|
22 |
+
def _generic_transform_sk_3d(transform, in_type='', out_type=''):
|
23 |
+
def apply_transform_individual(input_):
|
24 |
+
device = input_.device
|
25 |
+
input_ = input_.cpu()
|
26 |
+
input_ = _convert(input_, in_type)
|
27 |
+
|
28 |
+
input_ = input_.permute(1, 2, 0).detach().numpy()
|
29 |
+
transformed = transform(input_)
|
30 |
+
output = torch.from_numpy(transformed).float().permute(2, 0, 1)
|
31 |
+
output = _convert(output, out_type)
|
32 |
+
return output.to(device)
|
33 |
+
|
34 |
+
def apply_transform(input_):
|
35 |
+
to_stack = []
|
36 |
+
for image in input_:
|
37 |
+
to_stack.append(apply_transform_individual(image))
|
38 |
+
return torch.stack(to_stack)
|
39 |
+
return apply_transform
|
40 |
+
|
41 |
+
hsv_to_rgb = _generic_transform_sk_3d(hsv2rgb)
|
42 |
+
|
43 |
+
def flow2color(flow, clip=0.0):
|
44 |
+
B, C, H, W = list(flow.size())
|
45 |
+
assert(C==2)
|
46 |
+
flow = flow[0:1].detach()
|
47 |
+
if clip==0:
|
48 |
+
clip = torch.max(torch.abs(flow)).item()
|
49 |
+
flow = torch.clamp(flow, -clip, clip)/clip
|
50 |
+
radius = torch.sqrt(torch.sum(flow**2, dim=1, keepdim=True)) # B,1,H,W
|
51 |
+
radius_clipped = torch.clamp(radius, 0.0, 1.0)
|
52 |
+
angle = torch.atan2(-flow[:, 1:2], -flow[:, 0:1]) / np.pi # B,1,H,W
|
53 |
+
hue = torch.clamp((angle + 1.0) / 2.0, 0.0, 1.0)
|
54 |
+
saturation = torch.ones_like(hue) * 0.75
|
55 |
+
value = radius_clipped
|
56 |
+
hsv = torch.cat([hue, saturation, value], dim=1) # B,3,H,W
|
57 |
+
flow = hsv_to_rgb(hsv)
|
58 |
+
flow = (flow*255.0).type(torch.ByteTensor)
|
59 |
+
return flow
|
60 |
+
|
61 |
+
COLORMAP_FILE = "./utils/bremm.png"
|
62 |
+
class ColorMap2d:
|
63 |
+
def __init__(self, filename=None):
|
64 |
+
self._colormap_file = filename or COLORMAP_FILE
|
65 |
+
self._img = (plt.imread(self._colormap_file)*255).astype(np.uint8)
|
66 |
+
|
67 |
+
self._height = self._img.shape[0]
|
68 |
+
self._width = self._img.shape[1]
|
69 |
+
|
70 |
+
def __call__(self, X):
|
71 |
+
assert len(X.shape) == 2
|
72 |
+
output = np.zeros((X.shape[0], 3), dtype=np.uint8)
|
73 |
+
for i in range(X.shape[0]):
|
74 |
+
x, y = X[i, :]
|
75 |
+
xp = int((self._width-1) * x)
|
76 |
+
yp = int((self._height-1) * y)
|
77 |
+
xp = np.clip(xp, 0, self._width-1)
|
78 |
+
yp = np.clip(yp, 0, self._height-1)
|
79 |
+
output[i, :] = self._img[yp, xp]
|
80 |
+
return output
|
81 |
+
|
82 |
+
def get_2d_colors(xys, H, W):
|
83 |
+
N,D = xys.shape
|
84 |
+
assert(D==2)
|
85 |
+
bremm = ColorMap2d()
|
86 |
+
xys[:,0] /= float(W-1)
|
87 |
+
xys[:,1] /= float(H-1)
|
88 |
+
colors = bremm(xys)
|
89 |
+
# print('colors', colors)
|
90 |
+
# colors = (colors[0]*255).astype(np.uint8)
|
91 |
+
# colors = (int(colors[0]),int(colors[1]),int(colors[2]))
|
92 |
+
return colors
|
93 |
+
|
94 |
+
|
95 |
+
def get_n_colors(N, sequential=False):
|
96 |
+
label_colors = []
|
97 |
+
for ii in range(N):
|
98 |
+
if sequential:
|
99 |
+
rgb = cm.winter(ii/(N-1))
|
100 |
+
rgb = (np.array(rgb) * 255).astype(np.uint8)[:3]
|
101 |
+
else:
|
102 |
+
rgb = np.zeros(3)
|
103 |
+
while np.sum(rgb) < 128: # ensure min brightness
|
104 |
+
rgb = np.random.randint(0,256,3)
|
105 |
+
label_colors.append(rgb)
|
106 |
+
return label_colors
|
107 |
+
|
108 |
+
def pca_embed(emb, keep, valid=None):
|
109 |
+
# helper function for reduce_emb
|
110 |
+
# emb is B,C,H,W
|
111 |
+
# keep is the number of principal components to keep
|
112 |
+
emb = emb + EPS
|
113 |
+
emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy() #this is B x H x W x C
|
114 |
+
|
115 |
+
if valid:
|
116 |
+
valid = valid.cpu().detach().numpy().reshape((H*W))
|
117 |
+
|
118 |
+
emb_reduced = list()
|
119 |
+
|
120 |
+
B, H, W, C = np.shape(emb)
|
121 |
+
for img in emb:
|
122 |
+
if np.isnan(img).any():
|
123 |
+
emb_reduced.append(np.zeros([H, W, keep]))
|
124 |
+
continue
|
125 |
+
|
126 |
+
pixels_kd = np.reshape(img, (H*W, C))
|
127 |
+
|
128 |
+
if valid:
|
129 |
+
pixels_kd_pca = pixels_kd[valid]
|
130 |
+
else:
|
131 |
+
pixels_kd_pca = pixels_kd
|
132 |
+
|
133 |
+
P = PCA(keep)
|
134 |
+
P.fit(pixels_kd_pca)
|
135 |
+
|
136 |
+
if valid:
|
137 |
+
pixels3d = P.transform(pixels_kd)*valid
|
138 |
+
else:
|
139 |
+
pixels3d = P.transform(pixels_kd)
|
140 |
+
|
141 |
+
out_img = np.reshape(pixels3d, [H,W,keep]).astype(np.float32)
|
142 |
+
if np.isnan(out_img).any():
|
143 |
+
emb_reduced.append(np.zeros([H, W, keep]))
|
144 |
+
continue
|
145 |
+
|
146 |
+
emb_reduced.append(out_img)
|
147 |
+
|
148 |
+
emb_reduced = np.stack(emb_reduced, axis=0).astype(np.float32)
|
149 |
+
|
150 |
+
return torch.from_numpy(emb_reduced).permute(0, 3, 1, 2)
|
151 |
+
|
152 |
+
def pca_embed_together(emb, keep):
|
153 |
+
# emb is B,C,H,W
|
154 |
+
# keep is the number of principal components to keep
|
155 |
+
emb = emb + EPS
|
156 |
+
emb = emb.permute(0, 2, 3, 1).cpu().detach().float().numpy() #this is B x H x W x C
|
157 |
+
|
158 |
+
B, H, W, C = np.shape(emb)
|
159 |
+
if np.isnan(emb).any():
|
160 |
+
return torch.zeros(B, keep, H, W)
|
161 |
+
|
162 |
+
pixelskd = np.reshape(emb, (B*H*W, C))
|
163 |
+
P = PCA(keep)
|
164 |
+
P.fit(pixelskd)
|
165 |
+
pixels3d = P.transform(pixelskd)
|
166 |
+
out_img = np.reshape(pixels3d, [B,H,W,keep]).astype(np.float32)
|
167 |
+
|
168 |
+
if np.isnan(out_img).any():
|
169 |
+
return torch.zeros(B, keep, H, W)
|
170 |
+
|
171 |
+
return torch.from_numpy(out_img).permute(0, 3, 1, 2)
|
172 |
+
|
173 |
+
def reduce_emb(emb, valid=None, inbound=None, together=False):
|
174 |
+
S, C, H, W = list(emb.size())
|
175 |
+
keep = 4
|
176 |
+
|
177 |
+
if together:
|
178 |
+
reduced_emb = pca_embed_together(emb, keep)
|
179 |
+
else:
|
180 |
+
reduced_emb = pca_embed(emb, keep, valid) #not im
|
181 |
+
|
182 |
+
reduced_emb = reduced_emb[:,1:]
|
183 |
+
reduced_emb = utils.basic.normalize(reduced_emb) - 0.5
|
184 |
+
if inbound is not None:
|
185 |
+
emb_inbound = emb*inbound
|
186 |
+
else:
|
187 |
+
emb_inbound = None
|
188 |
+
|
189 |
+
return reduced_emb, emb_inbound
|
190 |
+
|
191 |
+
def get_feat_pca(feat, valid=None):
|
192 |
+
B, C, D, W = list(feat.size())
|
193 |
+
pca, _ = reduce_emb(feat, valid=valid,inbound=None, together=True)
|
194 |
+
return pca
|
195 |
+
|
196 |
+
def gif_and_tile(ims, just_gif=False):
|
197 |
+
S = len(ims)
|
198 |
+
# each im is B x H x W x C
|
199 |
+
# i want a gif in the left, and the tiled frames on the right
|
200 |
+
# for the gif tool, this means making a B x S x H x W tensor
|
201 |
+
# where the leftmost part is sequential and the rest is tiled
|
202 |
+
gif = torch.stack(ims, dim=1)
|
203 |
+
if just_gif:
|
204 |
+
return gif
|
205 |
+
til = torch.cat(ims, dim=2)
|
206 |
+
til = til.unsqueeze(dim=1).repeat(1, S, 1, 1, 1)
|
207 |
+
im = torch.cat([gif, til], dim=3)
|
208 |
+
return im
|
209 |
+
|
210 |
+
def preprocess_color(x):
|
211 |
+
if isinstance(x, np.ndarray):
|
212 |
+
return x.astype(np.float32) * 1./255 - 0.5
|
213 |
+
else:
|
214 |
+
return x.float() * 1./255 - 0.5
|
215 |
+
|
216 |
+
def back2color(i, blacken_zeros=False):
|
217 |
+
if blacken_zeros:
|
218 |
+
const = torch.tensor([-0.5])
|
219 |
+
i = torch.where(i==0.0, const.cuda() if i.is_cuda else const, i)
|
220 |
+
return back2color(i)
|
221 |
+
else:
|
222 |
+
return ((i+0.5)*255).type(torch.ByteTensor)
|
223 |
+
|
224 |
+
def draw_frame_id_on_vis(vis, frame_id, scale=0.5, left=5, top=20, shadow=True):
|
225 |
+
|
226 |
+
rgb = vis.detach().cpu().numpy()[0]
|
227 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
228 |
+
rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
|
229 |
+
color = (255, 255, 255)
|
230 |
+
# print('putting frame id', frame_id)
|
231 |
+
|
232 |
+
frame_str = utils.basic.strnum(frame_id)
|
233 |
+
|
234 |
+
text_color_bg = (0,0,0)
|
235 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
236 |
+
text_size, _ = cv2.getTextSize(frame_str, font, scale, 1)
|
237 |
+
text_w, text_h = text_size
|
238 |
+
if shadow:
|
239 |
+
cv2.rectangle(rgb, (left, top-text_h), (left + text_w, top+1), text_color_bg, -1)
|
240 |
+
|
241 |
+
cv2.putText(
|
242 |
+
rgb,
|
243 |
+
frame_str,
|
244 |
+
(left, top), # from left, from top
|
245 |
+
font,
|
246 |
+
scale, # font scale (float)
|
247 |
+
color,
|
248 |
+
1) # font thickness (int)
|
249 |
+
rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB)
|
250 |
+
vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
|
251 |
+
return vis
|
252 |
+
|
253 |
+
def draw_frame_str_on_vis(vis, frame_str, scale=0.5, left=5, top=40, shadow=True):
|
254 |
+
|
255 |
+
rgb = vis.detach().cpu().numpy()[0]
|
256 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
257 |
+
rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
|
258 |
+
color = (255, 255, 255)
|
259 |
+
|
260 |
+
text_color_bg = (0,0,0)
|
261 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
262 |
+
text_size, _ = cv2.getTextSize(frame_str, font, scale, 1)
|
263 |
+
text_w, text_h = text_size
|
264 |
+
if shadow:
|
265 |
+
cv2.rectangle(rgb, (left, top-text_h), (left + text_w, top+1), text_color_bg, -1)
|
266 |
+
|
267 |
+
cv2.putText(
|
268 |
+
rgb,
|
269 |
+
frame_str,
|
270 |
+
(left, top), # from left, from top
|
271 |
+
font,
|
272 |
+
scale, # font scale (float)
|
273 |
+
color,
|
274 |
+
1) # font thickness (int)
|
275 |
+
rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB)
|
276 |
+
vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
|
277 |
+
return vis
|
278 |
+
|
279 |
+
class Summ_writer(object):
|
280 |
+
def __init__(self, writer, global_step, log_freq=10, fps=8, scalar_freq=100, just_gif=False):
|
281 |
+
self.writer = writer
|
282 |
+
self.global_step = global_step
|
283 |
+
self.log_freq = log_freq
|
284 |
+
self.scalar_freq = scalar_freq
|
285 |
+
self.fps = fps
|
286 |
+
self.just_gif = just_gif
|
287 |
+
self.maxwidth = 10000
|
288 |
+
self.save_this = (self.global_step % self.log_freq == 0)
|
289 |
+
self.scalar_freq = max(scalar_freq,1)
|
290 |
+
self.save_scalar = (self.global_step % self.scalar_freq == 0)
|
291 |
+
if self.save_this:
|
292 |
+
self.save_scalar = True
|
293 |
+
|
294 |
+
def summ_gif(self, name, tensor, blacken_zeros=False):
|
295 |
+
# tensor should be in B x S x C x H x W
|
296 |
+
|
297 |
+
assert tensor.dtype in {torch.uint8,torch.float32}
|
298 |
+
shape = list(tensor.shape)
|
299 |
+
|
300 |
+
if tensor.dtype == torch.float32:
|
301 |
+
tensor = back2color(tensor, blacken_zeros=blacken_zeros)
|
302 |
+
|
303 |
+
video_to_write = tensor[0:1]
|
304 |
+
|
305 |
+
S = video_to_write.shape[1]
|
306 |
+
if S==1:
|
307 |
+
# video_to_write is 1 x 1 x C x H x W
|
308 |
+
self.writer.add_image(name, video_to_write[0,0], global_step=self.global_step)
|
309 |
+
else:
|
310 |
+
self.writer.add_video(name, video_to_write, fps=self.fps, global_step=self.global_step)
|
311 |
+
|
312 |
+
return video_to_write
|
313 |
+
|
314 |
+
def summ_rgbs(self, name, ims, frame_ids=None, frame_strs=None, blacken_zeros=False, only_return=False):
|
315 |
+
if self.save_this:
|
316 |
+
|
317 |
+
ims = gif_and_tile(ims, just_gif=self.just_gif)
|
318 |
+
vis = ims
|
319 |
+
|
320 |
+
assert vis.dtype in {torch.uint8,torch.float32}
|
321 |
+
|
322 |
+
if vis.dtype == torch.float32:
|
323 |
+
vis = back2color(vis, blacken_zeros)
|
324 |
+
|
325 |
+
B, S, C, H, W = list(vis.shape)
|
326 |
+
|
327 |
+
if frame_ids is not None:
|
328 |
+
assert(len(frame_ids)==S)
|
329 |
+
for s in range(S):
|
330 |
+
vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s])
|
331 |
+
|
332 |
+
if frame_strs is not None:
|
333 |
+
assert(len(frame_strs)==S)
|
334 |
+
for s in range(S):
|
335 |
+
vis[:,s] = draw_frame_str_on_vis(vis[:,s], frame_strs[s])
|
336 |
+
|
337 |
+
if int(W) > self.maxwidth:
|
338 |
+
vis = vis[:,:,:,:self.maxwidth]
|
339 |
+
|
340 |
+
if only_return:
|
341 |
+
return vis
|
342 |
+
else:
|
343 |
+
return self.summ_gif(name, vis, blacken_zeros)
|
344 |
+
|
345 |
+
def summ_rgb(self, name, ims, blacken_zeros=False, frame_id=None, frame_str=None, only_return=False, halfres=False, shadow=True):
|
346 |
+
if self.save_this:
|
347 |
+
assert ims.dtype in {torch.uint8,torch.float32}
|
348 |
+
|
349 |
+
if ims.dtype == torch.float32:
|
350 |
+
ims = back2color(ims, blacken_zeros)
|
351 |
+
|
352 |
+
#ims is B x C x H x W
|
353 |
+
vis = ims[0:1] # just the first one
|
354 |
+
B, C, H, W = list(vis.shape)
|
355 |
+
|
356 |
+
if halfres:
|
357 |
+
vis = F.interpolate(vis, scale_factor=0.5)
|
358 |
+
|
359 |
+
if frame_id is not None:
|
360 |
+
vis = draw_frame_id_on_vis(vis, frame_id, shadow=shadow)
|
361 |
+
|
362 |
+
if frame_str is not None:
|
363 |
+
vis = draw_frame_str_on_vis(vis, frame_str, shadow=shadow)
|
364 |
+
|
365 |
+
if int(W) > self.maxwidth:
|
366 |
+
vis = vis[:,:,:,:self.maxwidth]
|
367 |
+
|
368 |
+
if only_return:
|
369 |
+
return vis
|
370 |
+
else:
|
371 |
+
return self.summ_gif(name, vis.unsqueeze(1), blacken_zeros)
|
372 |
+
|
373 |
+
def flow2color(self, flow, clip=0.0):
|
374 |
+
B, C, H, W = list(flow.size())
|
375 |
+
assert(C==2)
|
376 |
+
flow = flow[0:1].detach()
|
377 |
+
|
378 |
+
if False:
|
379 |
+
flow = flow[0].detach().cpu().permute(1,2,0).numpy() # H,W,2
|
380 |
+
if clip > 0:
|
381 |
+
clip_flow = clip
|
382 |
+
else:
|
383 |
+
clip_flow = None
|
384 |
+
im = utils.py.flow_to_image(flow, clip_flow=clip_flow, convert_to_bgr=True)
|
385 |
+
# im = utils.py.flow_to_image(flow, convert_to_bgr=True)
|
386 |
+
im = torch.from_numpy(im).permute(2,0,1).unsqueeze(0).byte() # 1,3,H,W
|
387 |
+
im = torch.flip(im, dims=[1]).clone() # BGR
|
388 |
+
|
389 |
+
# # i prefer black bkg
|
390 |
+
# white_pixels = (im == 255).all(dim=1, keepdim=True)
|
391 |
+
# im[white_pixels.expand(-1, 3, -1, -1)] = 0
|
392 |
+
|
393 |
+
return im
|
394 |
+
|
395 |
+
# flow_abs = torch.abs(flow)
|
396 |
+
# flow_mean = flow_abs.mean(dim=[1,2,3])
|
397 |
+
# flow_std = flow_abs.std(dim=[1,2,3])
|
398 |
+
if clip==0:
|
399 |
+
clip = torch.max(torch.abs(flow)).item()
|
400 |
+
|
401 |
+
# if clip:
|
402 |
+
flow = torch.clamp(flow, -clip, clip)/clip
|
403 |
+
# else:
|
404 |
+
# # # Apply some kind of normalization. Divide by the perceived maximum (mean + std*2)
|
405 |
+
# # flow_max = flow_mean + flow_std*2 + 1e-10
|
406 |
+
# # for b in range(B):
|
407 |
+
# # flow[b] = flow[b].clamp(-flow_max[b].item(), flow_max[b].item()) / flow_max[b].clamp(min=1)
|
408 |
+
|
409 |
+
# flow_max = torch.max(flow_abs[b])
|
410 |
+
# for b in range(B):
|
411 |
+
# flow[b] = flow[b].clamp(-flow_max.item(), flow_max.item()) / flow_max[b].clamp(min=1)
|
412 |
+
|
413 |
+
|
414 |
+
radius = torch.sqrt(torch.sum(flow**2, dim=1, keepdim=True)) #B x 1 x H x W
|
415 |
+
radius_clipped = torch.clamp(radius, 0.0, 1.0)
|
416 |
+
|
417 |
+
angle = torch.atan2(-flow[:, 1:2], -flow[:, 0:1]) / np.pi # B x 1 x H x W
|
418 |
+
|
419 |
+
hue = torch.clamp((angle + 1.0) / 2.0, 0.0, 1.0)
|
420 |
+
# hue = torch.mod(angle / (2 * np.pi) + 1.0, 1.0)
|
421 |
+
|
422 |
+
saturation = torch.ones_like(hue) * 0.75
|
423 |
+
value = radius_clipped
|
424 |
+
hsv = torch.cat([hue, saturation, value], dim=1) #B x 3 x H x W
|
425 |
+
|
426 |
+
#flow = tf.image.hsv_to_rgb(hsv)
|
427 |
+
flow = hsv_to_rgb(hsv)
|
428 |
+
flow = (flow*255.0).type(torch.ByteTensor)
|
429 |
+
# flow = torch.flip(flow, dims=[1]).clone() # BGR
|
430 |
+
return flow
|
431 |
+
|
432 |
+
def summ_flow(self, name, im, clip=0.0, only_return=False, frame_id=None, frame_str=None, shadow=True):
|
433 |
+
# flow is B x C x D x W
|
434 |
+
if self.save_this:
|
435 |
+
return self.summ_rgb(name, self.flow2color(im, clip=clip), only_return=only_return, frame_id=frame_id, frame_str=frame_str, shadow=shadow)
|
436 |
+
else:
|
437 |
+
return None
|
438 |
+
|
439 |
+
def summ_oneds(self, name, ims, frame_ids=None, frame_strs=None, bev=False, fro=False, logvis=False, reduce_max=False, max_val=0.0, norm=True, only_return=False, do_colorize=False):
|
440 |
+
if self.save_this:
|
441 |
+
if bev:
|
442 |
+
B, C, H, _, W = list(ims[0].shape)
|
443 |
+
if reduce_max:
|
444 |
+
ims = [torch.max(im, dim=3)[0] for im in ims]
|
445 |
+
else:
|
446 |
+
ims = [torch.mean(im, dim=3) for im in ims]
|
447 |
+
elif fro:
|
448 |
+
B, C, _, H, W = list(ims[0].shape)
|
449 |
+
if reduce_max:
|
450 |
+
ims = [torch.max(im, dim=2)[0] for im in ims]
|
451 |
+
else:
|
452 |
+
ims = [torch.mean(im, dim=2) for im in ims]
|
453 |
+
|
454 |
+
|
455 |
+
if len(ims) != 1: # sequence
|
456 |
+
im = gif_and_tile(ims, just_gif=self.just_gif)
|
457 |
+
else:
|
458 |
+
im = torch.stack(ims, dim=1) # single frame
|
459 |
+
|
460 |
+
B, S, C, H, W = list(im.shape)
|
461 |
+
|
462 |
+
if logvis and max_val:
|
463 |
+
max_val = np.log(max_val)
|
464 |
+
im = torch.log(torch.clamp(im, 0)+1.0)
|
465 |
+
im = torch.clamp(im, 0, max_val)
|
466 |
+
im = im/max_val
|
467 |
+
norm = False
|
468 |
+
elif max_val:
|
469 |
+
im = torch.clamp(im, 0, max_val)
|
470 |
+
im = im/max_val
|
471 |
+
norm = False
|
472 |
+
|
473 |
+
if norm:
|
474 |
+
# normalize before oned2inferno,
|
475 |
+
# so that the ranges are similar within B across S
|
476 |
+
im = utils.basic.normalize(im)
|
477 |
+
|
478 |
+
im = im.view(B*S, C, H, W)
|
479 |
+
vis = oned2inferno(im, norm=norm, do_colorize=do_colorize)
|
480 |
+
vis = vis.view(B, S, 3, H, W)
|
481 |
+
|
482 |
+
if frame_ids is not None:
|
483 |
+
assert(len(frame_ids)==S)
|
484 |
+
for s in range(S):
|
485 |
+
vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s])
|
486 |
+
|
487 |
+
if frame_strs is not None:
|
488 |
+
assert(len(frame_strs)==S)
|
489 |
+
for s in range(S):
|
490 |
+
vis[:,s] = draw_frame_str_on_vis(vis[:,s], frame_strs[s])
|
491 |
+
|
492 |
+
if W > self.maxwidth:
|
493 |
+
vis = vis[...,:self.maxwidth]
|
494 |
+
|
495 |
+
if only_return:
|
496 |
+
return vis
|
497 |
+
else:
|
498 |
+
self.summ_gif(name, vis)
|
499 |
+
|
500 |
+
def summ_oned(self, name, im, bev=False, fro=False, logvis=False, max_val=0, max_along_y=False, norm=True, frame_id=None, frame_str=None, only_return=False, shadow=True):
|
501 |
+
if self.save_this:
|
502 |
+
|
503 |
+
if bev:
|
504 |
+
B, C, H, _, W = list(im.shape)
|
505 |
+
if max_along_y:
|
506 |
+
im = torch.max(im, dim=3)[0]
|
507 |
+
else:
|
508 |
+
im = torch.mean(im, dim=3)
|
509 |
+
elif fro:
|
510 |
+
B, C, _, H, W = list(im.shape)
|
511 |
+
if max_along_y:
|
512 |
+
im = torch.max(im, dim=2)[0]
|
513 |
+
else:
|
514 |
+
im = torch.mean(im, dim=2)
|
515 |
+
else:
|
516 |
+
B, C, H, W = list(im.shape)
|
517 |
+
|
518 |
+
im = im[0:1] # just the first one
|
519 |
+
assert(C==1)
|
520 |
+
|
521 |
+
if logvis and max_val:
|
522 |
+
max_val = np.log(max_val)
|
523 |
+
im = torch.log(im)
|
524 |
+
im = torch.clamp(im, 0, max_val)
|
525 |
+
im = im/max_val
|
526 |
+
norm = False
|
527 |
+
elif max_val:
|
528 |
+
im = torch.clamp(im, 0, max_val)/max_val
|
529 |
+
norm = False
|
530 |
+
|
531 |
+
vis = oned2inferno(im, norm=norm)
|
532 |
+
if W > self.maxwidth:
|
533 |
+
vis = vis[...,:self.maxwidth]
|
534 |
+
return self.summ_rgb(name, vis, blacken_zeros=False, frame_id=frame_id, frame_str=frame_str, only_return=only_return, shadow=shadow)
|
535 |
+
|
536 |
+
|
537 |
+
def summ_feats(self, name, feats, valids=None, pca=True, fro=False, only_return=False, frame_ids=None, frame_strs=None):
|
538 |
+
if self.save_this:
|
539 |
+
if valids is not None:
|
540 |
+
valids = torch.stack(valids, dim=1)
|
541 |
+
|
542 |
+
feats = torch.stack(feats, dim=1)
|
543 |
+
# feats leads with B x S x C
|
544 |
+
|
545 |
+
if feats.ndim==6:
|
546 |
+
|
547 |
+
# feats is B x S x C x D x H x W
|
548 |
+
if fro:
|
549 |
+
reduce_dim = 3
|
550 |
+
else:
|
551 |
+
reduce_dim = 4
|
552 |
+
|
553 |
+
if valids is None:
|
554 |
+
feats = torch.mean(feats, dim=reduce_dim)
|
555 |
+
else:
|
556 |
+
valids = valids.repeat(1, 1, feats.size()[2], 1, 1, 1)
|
557 |
+
feats = utils.basic.reduce_masked_mean(feats, valids, dim=reduce_dim)
|
558 |
+
|
559 |
+
B, S, C, D, W = list(feats.size())
|
560 |
+
|
561 |
+
if not pca:
|
562 |
+
# feats leads with B x S x C
|
563 |
+
feats = torch.mean(torch.abs(feats), dim=2, keepdims=True)
|
564 |
+
# feats leads with B x S x 1
|
565 |
+
feats = torch.unbind(feats, dim=1)
|
566 |
+
return self.summ_oneds(name=name, ims=feats, norm=True, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
|
567 |
+
|
568 |
+
else:
|
569 |
+
__p = lambda x: utils.basic.pack_seqdim(x, B)
|
570 |
+
__u = lambda x: utils.basic.unpack_seqdim(x, B)
|
571 |
+
|
572 |
+
feats_ = __p(feats)
|
573 |
+
|
574 |
+
if valids is None:
|
575 |
+
feats_pca_ = get_feat_pca(feats_)
|
576 |
+
else:
|
577 |
+
valids_ = __p(valids)
|
578 |
+
feats_pca_ = get_feat_pca(feats_, valids)
|
579 |
+
|
580 |
+
feats_pca = __u(feats_pca_)
|
581 |
+
|
582 |
+
return self.summ_rgbs(name=name, ims=torch.unbind(feats_pca, dim=1), only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
|
583 |
+
|
584 |
+
def summ_feat(self, name, feat, valid=None, pca=True, only_return=False, bev=False, fro=False, frame_id=None, frame_str=None):
|
585 |
+
if self.save_this:
|
586 |
+
if feat.ndim==5: # B x C x D x H x W
|
587 |
+
|
588 |
+
if bev:
|
589 |
+
reduce_axis = 3
|
590 |
+
elif fro:
|
591 |
+
reduce_axis = 2
|
592 |
+
else:
|
593 |
+
# default to bev
|
594 |
+
reduce_axis = 3
|
595 |
+
|
596 |
+
if valid is None:
|
597 |
+
feat = torch.mean(feat, dim=reduce_axis)
|
598 |
+
else:
|
599 |
+
valid = valid.repeat(1, feat.size()[1], 1, 1, 1)
|
600 |
+
feat = utils.basic.reduce_masked_mean(feat, valid, dim=reduce_axis)
|
601 |
+
|
602 |
+
B, C, D, W = list(feat.shape)
|
603 |
+
|
604 |
+
if not pca:
|
605 |
+
feat = torch.mean(torch.abs(feat), dim=1, keepdims=True)
|
606 |
+
# feat is B x 1 x D x W
|
607 |
+
return self.summ_oned(name=name, im=feat, norm=True, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
|
608 |
+
else:
|
609 |
+
feat_pca = get_feat_pca(feat, valid)
|
610 |
+
return self.summ_rgb(name, feat_pca, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
|
611 |
+
|
612 |
+
def summ_scalar(self, name, value):
|
613 |
+
if (not (isinstance(value, int) or isinstance(value, float) or isinstance(value, np.float32))) and ('torch' in value.type()):
|
614 |
+
value = value.detach().cpu().numpy()
|
615 |
+
if not np.isnan(value):
|
616 |
+
if (self.log_freq == 1):
|
617 |
+
self.writer.add_scalar(name, value, global_step=self.global_step)
|
618 |
+
elif self.save_this or self.save_scalar:
|
619 |
+
self.writer.add_scalar(name, value, global_step=self.global_step)
|
620 |
+
|
621 |
+
def summ_traj2ds_on_rgbs(self, name, trajs, rgbs, visibs=None, valids=None, frame_ids=None, frame_strs=None, only_return=False, show_dots=True, cmap='coolwarm', vals=None, linewidth=1, max_show=1024):
|
622 |
+
# trajs is B, S, N, 2
|
623 |
+
# rgbs is B, S, C, H, W
|
624 |
+
B, S, C, H, W = rgbs.shape
|
625 |
+
B, S2, N, D = trajs.shape
|
626 |
+
assert(S==S2)
|
627 |
+
|
628 |
+
|
629 |
+
rgbs = rgbs[0] # S, C, H, W
|
630 |
+
trajs = trajs[0] # S, N, 2
|
631 |
+
if valids is None:
|
632 |
+
valids = torch.ones_like(trajs[:,:,0]) # S, N
|
633 |
+
else:
|
634 |
+
valids = valids[0]
|
635 |
+
|
636 |
+
if visibs is None:
|
637 |
+
visibs = torch.ones_like(trajs[:,:,0]) # S, N
|
638 |
+
else:
|
639 |
+
visibs = visibs[0]
|
640 |
+
|
641 |
+
if vals is not None:
|
642 |
+
vals = vals[0] # N
|
643 |
+
# print('vals', vals.shape)
|
644 |
+
|
645 |
+
if N > max_show:
|
646 |
+
inds = np.random.choice(N, max_show)
|
647 |
+
trajs = trajs[:,inds]
|
648 |
+
valids = valids[:,inds]
|
649 |
+
visibs = visibs[:,inds]
|
650 |
+
if vals is not None:
|
651 |
+
vals = vals[inds]
|
652 |
+
N = trajs.shape[1]
|
653 |
+
|
654 |
+
trajs = trajs.clamp(-16, W+16)
|
655 |
+
|
656 |
+
rgbs_color = []
|
657 |
+
for rgb in rgbs:
|
658 |
+
rgb = back2color(rgb).detach().cpu().numpy()
|
659 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
660 |
+
rgbs_color.append(rgb) # each element 3 x H x W
|
661 |
+
|
662 |
+
for i in range(min(N, max_show)):
|
663 |
+
if cmap=='onediff' and i==0:
|
664 |
+
cmap_ = 'spring'
|
665 |
+
elif cmap=='onediff':
|
666 |
+
cmap_ = 'winter'
|
667 |
+
else:
|
668 |
+
cmap_ = cmap
|
669 |
+
traj = trajs[:,i].long().detach().cpu().numpy() # S, 2
|
670 |
+
valid = valids[:,i].long().detach().cpu().numpy() # S
|
671 |
+
|
672 |
+
# print('traj', traj.shape)
|
673 |
+
# print('valid', valid.shape)
|
674 |
+
|
675 |
+
if vals is not None:
|
676 |
+
# val = vals[:,i].float().detach().cpu().numpy() # []
|
677 |
+
val = vals[i].float().detach().cpu().numpy() # []
|
678 |
+
# print('val', val.shape)
|
679 |
+
else:
|
680 |
+
val = None
|
681 |
+
|
682 |
+
for t in range(S):
|
683 |
+
if valid[t]:
|
684 |
+
rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj[:t+1], S=S, show_dots=show_dots, cmap=cmap_, val=val, linewidth=linewidth)
|
685 |
+
|
686 |
+
for i in range(min(N, max_show)):
|
687 |
+
if cmap=='onediff' and i==0:
|
688 |
+
cmap_ = 'spring'
|
689 |
+
elif cmap=='onediff':
|
690 |
+
cmap_ = 'winter'
|
691 |
+
else:
|
692 |
+
cmap_ = cmap
|
693 |
+
traj = trajs[:,i] # S,2
|
694 |
+
vis = visibs[:,i].round() # S
|
695 |
+
valid = valids[:,i] # S
|
696 |
+
rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth)
|
697 |
+
|
698 |
+
rgbs = []
|
699 |
+
for rgb in rgbs_color:
|
700 |
+
rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
|
701 |
+
rgbs.append(preprocess_color(rgb))
|
702 |
+
|
703 |
+
return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
|
704 |
+
|
705 |
+
def summ_traj2ds_on_rgbs2(self, name, trajs, visibles, rgbs, valids=None, frame_ids=None, frame_strs=None, only_return=False, show_dots=True, cmap=None, linewidth=1, max_show=1024):
|
706 |
+
# trajs is B, S, N, 2
|
707 |
+
# rgbs is B, S, C, H, W
|
708 |
+
B, S, C, H, W = rgbs.shape
|
709 |
+
B, S2, N, D = trajs.shape
|
710 |
+
assert(S==S2)
|
711 |
+
|
712 |
+
rgbs = rgbs[0] # S, C, H, W
|
713 |
+
trajs = trajs[0] # S, N, 2
|
714 |
+
visibles = visibles[0] # S, N
|
715 |
+
if valids is None:
|
716 |
+
valids = torch.ones_like(trajs[:,:,0]) # S, N
|
717 |
+
else:
|
718 |
+
valids = valids[0]
|
719 |
+
|
720 |
+
rgbs_color = []
|
721 |
+
for rgb in rgbs:
|
722 |
+
rgb = back2color(rgb).detach().cpu().numpy()
|
723 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
724 |
+
rgbs_color.append(rgb) # each element 3 x H x W
|
725 |
+
|
726 |
+
trajs = trajs.long().detach().cpu().numpy() # S, N, 2
|
727 |
+
visibles = visibles.float().detach().cpu().numpy() # S, N
|
728 |
+
valids = valids.long().detach().cpu().numpy() # S, N
|
729 |
+
|
730 |
+
for i in range(min(N, max_show)):
|
731 |
+
if cmap=='onediff' and i==0:
|
732 |
+
cmap_ = 'spring'
|
733 |
+
elif cmap=='onediff':
|
734 |
+
cmap_ = 'winter'
|
735 |
+
else:
|
736 |
+
cmap_ = cmap
|
737 |
+
traj = trajs[:,i] # S,2
|
738 |
+
vis = visibles[:,i] # S
|
739 |
+
valid = valids[:,i] # S
|
740 |
+
rgbs_color = self.draw_traj_on_images_py(rgbs_color, traj, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth)
|
741 |
+
|
742 |
+
for i in range(min(N, max_show)):
|
743 |
+
if cmap=='onediff' and i==0:
|
744 |
+
cmap_ = 'spring'
|
745 |
+
elif cmap=='onediff':
|
746 |
+
cmap_ = 'winter'
|
747 |
+
else:
|
748 |
+
cmap_ = cmap
|
749 |
+
traj = trajs[:,i] # S,2
|
750 |
+
vis = visibles[:,i] # S
|
751 |
+
valid = valids[:,i] # S
|
752 |
+
rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=None, linewidth=linewidth)
|
753 |
+
|
754 |
+
rgbs = []
|
755 |
+
for rgb in rgbs_color:
|
756 |
+
rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
|
757 |
+
rgbs.append(preprocess_color(rgb))
|
758 |
+
|
759 |
+
return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
|
760 |
+
|
761 |
+
def summ_traj2ds_on_rgb(self, name, trajs, rgb, valids=None, show_dots=True, show_lines=True, frame_id=None, frame_str=None, only_return=False, cmap='coolwarm', linewidth=1, max_show=1024):
|
762 |
+
# trajs is B, S, N, 2
|
763 |
+
# rgb is B, C, H, W
|
764 |
+
B, C, H, W = rgb.shape
|
765 |
+
B, S, N, D = trajs.shape
|
766 |
+
|
767 |
+
rgb = rgb[0] # S, C, H, W
|
768 |
+
trajs = trajs[0] # S, N, 2
|
769 |
+
|
770 |
+
if valids is None:
|
771 |
+
valids = torch.ones_like(trajs[:,:,0])
|
772 |
+
else:
|
773 |
+
valids = valids[0]
|
774 |
+
|
775 |
+
rgb_color = back2color(rgb).detach().cpu().numpy()
|
776 |
+
rgb_color = np.transpose(rgb_color, [1, 2, 0]) # put channels last
|
777 |
+
|
778 |
+
# using maxdist will dampen the colors for short motions
|
779 |
+
# norms = torch.sqrt(1e-4 + torch.sum((trajs[-1] - trajs[0])**2, dim=1)) # N
|
780 |
+
# maxdist = torch.quantile(norms, 0.95).detach().cpu().numpy()
|
781 |
+
maxdist = None
|
782 |
+
trajs = trajs.long().detach().cpu().numpy() # S, N, 2
|
783 |
+
valids = valids.long().detach().cpu().numpy() # S, N
|
784 |
+
|
785 |
+
if N > max_show:
|
786 |
+
inds = np.random.choice(N, max_show)
|
787 |
+
trajs = trajs[:,inds]
|
788 |
+
valids = valids[:,inds]
|
789 |
+
N = trajs.shape[1]
|
790 |
+
|
791 |
+
for i in range(min(N, max_show)):
|
792 |
+
if cmap=='onediff' and i==0:
|
793 |
+
cmap_ = 'spring'
|
794 |
+
elif cmap=='onediff':
|
795 |
+
cmap_ = 'winter'
|
796 |
+
else:
|
797 |
+
cmap_ = cmap
|
798 |
+
traj = trajs[:,i] # S, 2
|
799 |
+
valid = valids[:,i] # S
|
800 |
+
if valid[0]==1:
|
801 |
+
traj = traj[valid>0]
|
802 |
+
rgb_color = self.draw_traj_on_image_py(
|
803 |
+
rgb_color, traj, S=S, show_dots=show_dots, show_lines=show_lines, cmap=cmap_, maxdist=maxdist, linewidth=linewidth)
|
804 |
+
|
805 |
+
rgb_color = torch.from_numpy(rgb_color).permute(2, 0, 1).unsqueeze(0)
|
806 |
+
rgb = preprocess_color(rgb_color)
|
807 |
+
return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
|
808 |
+
|
809 |
+
def draw_traj_on_image_py(self, rgb, traj, S=50, linewidth=1, show_dots=False, show_lines=True, cmap='coolwarm', val=None, maxdist=None):
|
810 |
+
# all inputs are numpy tensors
|
811 |
+
# rgb is 3 x H x W
|
812 |
+
# traj is S x 2
|
813 |
+
|
814 |
+
H, W, C = rgb.shape
|
815 |
+
assert(C==3)
|
816 |
+
|
817 |
+
rgb = rgb.astype(np.uint8).copy()
|
818 |
+
|
819 |
+
S1, D = traj.shape
|
820 |
+
assert(D==2)
|
821 |
+
|
822 |
+
color_map = cm.get_cmap(cmap)
|
823 |
+
S1, D = traj.shape
|
824 |
+
|
825 |
+
for s in range(S1):
|
826 |
+
if val is not None:
|
827 |
+
color = np.array(color_map(val)[:3]) * 255 # rgb
|
828 |
+
else:
|
829 |
+
if maxdist is not None:
|
830 |
+
val = (np.sqrt(np.sum((traj[s]-traj[0])**2))/maxdist).clip(0,1)
|
831 |
+
color = np.array(color_map(val)[:3]) * 255 # rgb
|
832 |
+
else:
|
833 |
+
color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # rgb
|
834 |
+
|
835 |
+
if show_lines and s<(S1-1):
|
836 |
+
cv2.line(rgb,
|
837 |
+
(int(traj[s,0]), int(traj[s,1])),
|
838 |
+
(int(traj[s+1,0]), int(traj[s+1,1])),
|
839 |
+
color,
|
840 |
+
linewidth,
|
841 |
+
cv2.LINE_AA)
|
842 |
+
if show_dots:
|
843 |
+
cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, -1)
|
844 |
+
|
845 |
+
# if maxdist is not None:
|
846 |
+
# val = (np.sqrt(np.sum((traj[-1]-traj[0])**2))/maxdist).clip(0,1)
|
847 |
+
# color = np.array(color_map(val)[:3]) * 255 # rgb
|
848 |
+
# else:
|
849 |
+
# # draw the endpoint of traj, using the next color (which may be the last color)
|
850 |
+
# color = np.array(color_map((S1-1)/max(1,float(S-2)))[:3]) * 255 # rgb
|
851 |
+
|
852 |
+
# # emphasize endpoint
|
853 |
+
# cv2.circle(rgb, (traj[-1,0], traj[-1,1]), linewidth*2, color, -1)
|
854 |
+
|
855 |
+
return rgb
|
856 |
+
|
857 |
+
|
858 |
+
def draw_traj_on_images_py(self, rgbs, traj, S=50, linewidth=1, show_dots=False, cmap='coolwarm', maxdist=None):
|
859 |
+
# all inputs are numpy tensors
|
860 |
+
# rgbs is a list of H,W,3
|
861 |
+
# traj is S,2
|
862 |
+
H, W, C = rgbs[0].shape
|
863 |
+
assert(C==3)
|
864 |
+
|
865 |
+
rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs]
|
866 |
+
|
867 |
+
S1, D = traj.shape
|
868 |
+
assert(D==2)
|
869 |
+
|
870 |
+
x = int(np.clip(traj[0,0], 0, W-1))
|
871 |
+
y = int(np.clip(traj[0,1], 0, H-1))
|
872 |
+
color = rgbs[0][y,x]
|
873 |
+
color = (int(color[0]),int(color[1]),int(color[2]))
|
874 |
+
for s in range(S):
|
875 |
+
# bak_color = np.array(color_map(1.0)[:3]) * 255 # rgb
|
876 |
+
# cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth*4, bak_color, -1)
|
877 |
+
cv2.polylines(rgbs[s],
|
878 |
+
[traj[:s+1]],
|
879 |
+
False,
|
880 |
+
color,
|
881 |
+
linewidth,
|
882 |
+
cv2.LINE_AA)
|
883 |
+
return rgbs
|
884 |
+
|
885 |
+
def draw_circs_on_image_py(self, rgb, xy, colors=None, linewidth=10, radius=3, show_dots=False, maxdist=None):
|
886 |
+
# all inputs are numpy tensors
|
887 |
+
# rgbs is a list of 3,H,W
|
888 |
+
# xy is N,2
|
889 |
+
H, W, C = rgb.shape
|
890 |
+
assert(C==3)
|
891 |
+
|
892 |
+
rgb = rgb.astype(np.uint8).copy()
|
893 |
+
|
894 |
+
N, D = xy.shape
|
895 |
+
assert(D==2)
|
896 |
+
|
897 |
+
|
898 |
+
xy = xy.astype(np.float32)
|
899 |
+
xy[:,0] = np.clip(xy[:,0], 0, W-1)
|
900 |
+
xy[:,1] = np.clip(xy[:,1], 0, H-1)
|
901 |
+
xy = xy.astype(np.int32)
|
902 |
+
|
903 |
+
|
904 |
+
|
905 |
+
if colors is None:
|
906 |
+
colors = get_n_colors(N)
|
907 |
+
|
908 |
+
for n in range(N):
|
909 |
+
color = colors[n]
|
910 |
+
# print('color', color)
|
911 |
+
# color = (color[0]*255).astype(np.uint8)
|
912 |
+
color = (int(color[0]),int(color[1]),int(color[2]))
|
913 |
+
|
914 |
+
# x = int(np.clip(xy[0,0], 0, W-1))
|
915 |
+
# y = int(np.clip(xy[0,1], 0, H-1))
|
916 |
+
# color_ = rgbs[0][y,x]
|
917 |
+
# color_ = (int(color_[0]),int(color_[1]),int(color_[2]))
|
918 |
+
# color_ = (int(color_[0]),int(color_[1]),int(color_[2]))
|
919 |
+
|
920 |
+
cv2.circle(rgb, (int(xy[n,0]), int(xy[n,1])), linewidth, color, 3)
|
921 |
+
# vis_color = int(np.squeeze(vis[s])*255)
|
922 |
+
# vis_color = (vis_color,vis_color,vis_color)
|
923 |
+
# cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth+1, vis_color, -1)
|
924 |
+
return rgb
|
925 |
+
|
926 |
+
def draw_circ_on_images_py(self, rgbs, traj, vis, S=50, linewidth=1, show_dots=False, cmap=None, maxdist=None):
|
927 |
+
# all inputs are numpy tensors
|
928 |
+
# rgbs is a list of 3,H,W
|
929 |
+
# traj is S,2
|
930 |
+
H, W, C = rgbs[0].shape
|
931 |
+
assert(C==3)
|
932 |
+
|
933 |
+
rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs]
|
934 |
+
|
935 |
+
S1, D = traj.shape
|
936 |
+
assert(D==2)
|
937 |
+
|
938 |
+
if cmap is None:
|
939 |
+
bremm = ColorMap2d()
|
940 |
+
traj_ = traj[0:1].astype(np.float32)
|
941 |
+
traj_[:,0] /= float(W)
|
942 |
+
traj_[:,1] /= float(H)
|
943 |
+
color = bremm(traj_)
|
944 |
+
# print('color', color)
|
945 |
+
color = (color[0]*255).astype(np.uint8)
|
946 |
+
color = (int(color[0]),int(color[1]),int(color[2]))
|
947 |
+
|
948 |
+
for s in range(S):
|
949 |
+
if cmap is not None:
|
950 |
+
color_map = cm.get_cmap(cmap)
|
951 |
+
# color = np.array(color_map(s/(S-1))[:3]) * 255 # rgb
|
952 |
+
color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # rgb
|
953 |
+
# color = color.astype(np.uint8)
|
954 |
+
# color = (color[0], color[1], color[2])
|
955 |
+
# print('color', color)
|
956 |
+
# import ipdb; ipdb.set_trace()
|
957 |
+
|
958 |
+
cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+2, color, -1)
|
959 |
+
vis_color = int(np.squeeze(vis[s])*255)
|
960 |
+
vis_color = (vis_color,vis_color,vis_color)
|
961 |
+
cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, vis_color, -1)
|
962 |
+
|
963 |
+
return rgbs
|
964 |
+
|
965 |
+
def summ_pts_on_rgb(self, name, trajs, rgb, visibs=None, valids=None, frame_id=None, frame_str=None, only_return=False, show_dots=True, colors=None, cmap='coolwarm', linewidth=1, max_show=1024, already_sorted=False):
|
966 |
+
# trajs is B, S, N, 2
|
967 |
+
# rgbs is B, S, C, H, W
|
968 |
+
B, C, H, W = rgb.shape
|
969 |
+
B, S, N, D = trajs.shape
|
970 |
+
|
971 |
+
rgb = rgb[0] # C, H, W
|
972 |
+
trajs = trajs[0] # S, N, 2
|
973 |
+
if valids is None:
|
974 |
+
valids = torch.ones_like(trajs[:,:,0]) # S, N
|
975 |
+
else:
|
976 |
+
valids = valids[0]
|
977 |
+
if visibs is None:
|
978 |
+
visibs = torch.ones_like(trajs[:,:,0]) # S, N
|
979 |
+
else:
|
980 |
+
visibs = visibs[0]
|
981 |
+
|
982 |
+
trajs = trajs.clamp(-16, W+16)
|
983 |
+
|
984 |
+
if N > max_show:
|
985 |
+
inds = np.random.choice(N, max_show)
|
986 |
+
trajs = trajs[:,inds]
|
987 |
+
valids = valids[:,inds]
|
988 |
+
visibs = visibs[:,inds]
|
989 |
+
N = trajs.shape[1]
|
990 |
+
|
991 |
+
if not already_sorted:
|
992 |
+
inds = torch.argsort(torch.mean(trajs[:,:,1], dim=0))
|
993 |
+
trajs = trajs[:,inds]
|
994 |
+
valids = valids[:,inds]
|
995 |
+
visibs = visibs[:,inds]
|
996 |
+
|
997 |
+
rgb = back2color(rgb).detach().cpu().numpy()
|
998 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
999 |
+
|
1000 |
+
trajs = trajs.long().detach().cpu().numpy() # S, N, 2
|
1001 |
+
valids = valids.long().detach().cpu().numpy() # S, N
|
1002 |
+
visibs = visibs.long().detach().cpu().numpy() # S, N
|
1003 |
+
|
1004 |
+
rgb = rgb.astype(np.uint8).copy()
|
1005 |
+
|
1006 |
+
for i in range(min(N, max_show)):
|
1007 |
+
if cmap=='onediff' and i==0:
|
1008 |
+
cmap_ = 'spring'
|
1009 |
+
elif cmap=='onediff':
|
1010 |
+
cmap_ = 'winter'
|
1011 |
+
else:
|
1012 |
+
cmap_ = cmap
|
1013 |
+
traj = trajs[:,i] # S,2
|
1014 |
+
valid = valids[:,i] # S
|
1015 |
+
visib = visibs[:,i] # S
|
1016 |
+
|
1017 |
+
if colors is None:
|
1018 |
+
ii = i/(1e-4+N-1.0)
|
1019 |
+
color_map = cm.get_cmap(cmap)
|
1020 |
+
color = np.array(color_map(ii)[:3]) * 255 # rgb
|
1021 |
+
else:
|
1022 |
+
color = np.array(colors[i]).astype(np.int64)
|
1023 |
+
color = (int(color[0]),int(color[1]),int(color[2]))
|
1024 |
+
|
1025 |
+
for s in range(S):
|
1026 |
+
if valid[s]:
|
1027 |
+
if visib[s]:
|
1028 |
+
thickness = -1
|
1029 |
+
else:
|
1030 |
+
thickness = 2
|
1031 |
+
cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, thickness)
|
1032 |
+
rgb = torch.from_numpy(rgb).permute(2,0,1).unsqueeze(0)
|
1033 |
+
rgb = preprocess_color(rgb)
|
1034 |
+
return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
|
1035 |
+
|
1036 |
+
def summ_pts_on_rgbs(self, name, trajs, rgbs, visibs=None, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap='coolwarm', colors=None, linewidth=1, max_show=1024, frame_strs=None):
|
1037 |
+
# trajs is B, S, N, 2
|
1038 |
+
# rgbs is B, S, C, H, W
|
1039 |
+
B, S, C, H, W = rgbs.shape
|
1040 |
+
B, S2, N, D = trajs.shape
|
1041 |
+
assert(S==S2)
|
1042 |
+
|
1043 |
+
rgbs = rgbs[0] # S, C, H, W
|
1044 |
+
trajs = trajs[0] # S, N, 2
|
1045 |
+
if valids is None:
|
1046 |
+
valids = torch.ones_like(trajs[:,:,0]) # S, N
|
1047 |
+
else:
|
1048 |
+
valids = valids[0]
|
1049 |
+
if visibs is None:
|
1050 |
+
visibs = torch.ones_like(trajs[:,:,0]) # S, N
|
1051 |
+
else:
|
1052 |
+
visibs = visibs[0]
|
1053 |
+
|
1054 |
+
if N > max_show:
|
1055 |
+
inds = np.random.choice(N, max_show)
|
1056 |
+
trajs = trajs[:,inds]
|
1057 |
+
valids = valids[:,inds]
|
1058 |
+
visibs = visibs[:,inds]
|
1059 |
+
N = trajs.shape[1]
|
1060 |
+
inds = torch.argsort(torch.mean(trajs[:,:,1], dim=0))
|
1061 |
+
trajs = trajs[:,inds]
|
1062 |
+
valids = valids[:,inds]
|
1063 |
+
visibs = visibs[:,inds]
|
1064 |
+
|
1065 |
+
rgbs_color = []
|
1066 |
+
for rgb in rgbs:
|
1067 |
+
rgb = back2color(rgb).detach().cpu().numpy()
|
1068 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
1069 |
+
rgbs_color.append(rgb) # each element 3 x H x W
|
1070 |
+
|
1071 |
+
trajs = trajs.long().detach().cpu().numpy() # S, N, 2
|
1072 |
+
valids = valids.long().detach().cpu().numpy() # S, N
|
1073 |
+
visibs = visibs.long().detach().cpu().numpy() # S, N
|
1074 |
+
|
1075 |
+
rgbs_color = [rgb.astype(np.uint8).copy() for rgb in rgbs_color]
|
1076 |
+
|
1077 |
+
for i in range(min(N, max_show)):
|
1078 |
+
traj = trajs[:,i] # S,2
|
1079 |
+
valid = valids[:,i] # S
|
1080 |
+
visib = visibs[:,i] # S
|
1081 |
+
|
1082 |
+
if colors is None:
|
1083 |
+
ii = i/(1e-4+N-1.0)
|
1084 |
+
color_map = cm.get_cmap(cmap)
|
1085 |
+
color = np.array(color_map(ii)[:3]) * 255 # rgb
|
1086 |
+
else:
|
1087 |
+
color = np.array(colors[i]).astype(np.int64)
|
1088 |
+
color = (int(color[0]),int(color[1]),int(color[2]))
|
1089 |
+
|
1090 |
+
for s in range(S):
|
1091 |
+
if valid[s]:
|
1092 |
+
if visib[s]:
|
1093 |
+
thickness = -1
|
1094 |
+
else:
|
1095 |
+
thickness = 2
|
1096 |
+
cv2.circle(rgbs_color[s], (int(traj[s,0]), int(traj[s,1])), int(linewidth), color, thickness)
|
1097 |
+
rgbs = []
|
1098 |
+
for rgb in rgbs_color:
|
1099 |
+
rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
|
1100 |
+
rgbs.append(preprocess_color(rgb))
|
1101 |
+
|
1102 |
+
return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
|
1103 |
+
|
utils/loss.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch.nn as nn
|
4 |
+
from typing import List
|
5 |
+
import utils.basic
|
6 |
+
|
7 |
+
|
8 |
+
def sequence_loss(
|
9 |
+
flow_preds,
|
10 |
+
flow_gt,
|
11 |
+
valids,
|
12 |
+
vis=None,
|
13 |
+
gamma=0.8,
|
14 |
+
use_huber_loss=False,
|
15 |
+
loss_only_for_visible=False,
|
16 |
+
):
|
17 |
+
"""Loss function defined over sequence of flow predictions"""
|
18 |
+
total_flow_loss = 0.0
|
19 |
+
for j in range(len(flow_gt)):
|
20 |
+
B, S, N, D = flow_gt[j].shape
|
21 |
+
B, S2, N = valids[j].shape
|
22 |
+
assert S == S2
|
23 |
+
n_predictions = len(flow_preds[j])
|
24 |
+
flow_loss = 0.0
|
25 |
+
for i in range(n_predictions):
|
26 |
+
i_weight = gamma ** (n_predictions - i - 1)
|
27 |
+
flow_pred = flow_preds[j][i]
|
28 |
+
if use_huber_loss:
|
29 |
+
i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0)
|
30 |
+
else:
|
31 |
+
i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
|
32 |
+
i_loss = torch.mean(i_loss, dim=3) # B, S, N
|
33 |
+
valid_ = valids[j].clone()
|
34 |
+
if loss_only_for_visible:
|
35 |
+
valid_ = valid_ * vis[j]
|
36 |
+
flow_loss += i_weight * utils.basic.reduce_masked_mean(i_loss, valid_)
|
37 |
+
flow_loss = flow_loss / n_predictions
|
38 |
+
total_flow_loss += flow_loss
|
39 |
+
return total_flow_loss / len(flow_gt)
|
40 |
+
|
41 |
+
def sequence_loss_dense(
|
42 |
+
flow_preds,
|
43 |
+
flow_gt,
|
44 |
+
valids,
|
45 |
+
vis=None,
|
46 |
+
gamma=0.8,
|
47 |
+
use_huber_loss=False,
|
48 |
+
loss_only_for_visible=False,
|
49 |
+
):
|
50 |
+
"""Loss function defined over sequence of flow predictions"""
|
51 |
+
total_flow_loss = 0.0
|
52 |
+
for j in range(len(flow_gt)):
|
53 |
+
# print('flow_gt[j]', flow_gt[j].shape)
|
54 |
+
B, S, D, H, W = flow_gt[j].shape
|
55 |
+
B, S2, _, H, W = valids[j].shape
|
56 |
+
assert S == S2
|
57 |
+
n_predictions = len(flow_preds[j])
|
58 |
+
flow_loss = 0.0
|
59 |
+
# import ipdb; ipdb.set_trace()
|
60 |
+
for i in range(n_predictions):
|
61 |
+
# print('flow_e[j][i]', flow_preds[j][i].shape)
|
62 |
+
i_weight = gamma ** (n_predictions - i - 1)
|
63 |
+
flow_pred = flow_preds[j][i] # B,S,2,H,W
|
64 |
+
if use_huber_loss:
|
65 |
+
i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0) # B,S,2,H,W
|
66 |
+
else:
|
67 |
+
i_loss = (flow_pred - flow_gt[j]).abs() # B,S,2,H,W
|
68 |
+
i_loss_ = torch.mean(i_loss, dim=2) # B,S,H,W
|
69 |
+
valid_ = valids[j].reshape(B,S,H,W)
|
70 |
+
# print(' (%d,%d) i_loss_' % (i,j), i_loss_.shape)
|
71 |
+
# print(' (%d,%d) valid_' % (i,j), valid_.shape)
|
72 |
+
if loss_only_for_visible:
|
73 |
+
valid_ = valid_ * vis[j].reshape(B,-1,H,W) # usually B,S,H,W, but maybe B,1,H,W
|
74 |
+
flow_loss += i_weight * utils.basic.reduce_masked_mean(i_loss_, valid_, broadcast=True)
|
75 |
+
# import ipdb; ipdb.set_trace()
|
76 |
+
flow_loss = flow_loss / n_predictions
|
77 |
+
total_flow_loss += flow_loss
|
78 |
+
return total_flow_loss / len(flow_gt)
|
79 |
+
|
80 |
+
|
81 |
+
def huber_loss(x, y, delta=1.0):
|
82 |
+
"""Calculate element-wise Huber loss between x and y"""
|
83 |
+
diff = x - y
|
84 |
+
abs_diff = diff.abs()
|
85 |
+
flag = (abs_diff <= delta).float()
|
86 |
+
return flag * 0.5 * diff**2 + (1 - flag) * delta * (abs_diff - 0.5 * delta)
|
87 |
+
|
88 |
+
|
89 |
+
def sequence_BCE_loss(vis_preds, vis_gts, valids=None, use_logits=False):
|
90 |
+
total_bce_loss = 0.0
|
91 |
+
# all_vis_preds = [torch.stack(vp) for vp in vis_preds]
|
92 |
+
# all_vis_preds = torch.stack(all_vis_preds)
|
93 |
+
# utils.basic.print_stats('all_vis_preds', all_vis_preds)
|
94 |
+
for j in range(len(vis_preds)):
|
95 |
+
n_predictions = len(vis_preds[j])
|
96 |
+
bce_loss = 0.0
|
97 |
+
for i in range(n_predictions):
|
98 |
+
# utils.basic.print_stats('vis_preds[%d][%d]' % (j,i), vis_preds[j][i])
|
99 |
+
# utils.basic.print_stats('vis_gts[%d]' % (i), vis_gts[i])
|
100 |
+
if use_logits:
|
101 |
+
loss = F.binary_cross_entropy_with_logits(vis_preds[j][i], vis_gts[j], reduction='none')
|
102 |
+
else:
|
103 |
+
loss = F.binary_cross_entropy(vis_preds[j][i], vis_gts[j], reduction='none')
|
104 |
+
if valids is None:
|
105 |
+
bce_loss += loss.mean()
|
106 |
+
else:
|
107 |
+
bce_loss += (loss * valids[j]).mean()
|
108 |
+
bce_loss = bce_loss / n_predictions
|
109 |
+
total_bce_loss += bce_loss
|
110 |
+
return total_bce_loss / len(vis_preds)
|
111 |
+
|
112 |
+
|
113 |
+
# def sequence_BCE_loss_dense(vis_preds, vis_gts):
|
114 |
+
# total_bce_loss = 0.0
|
115 |
+
# for j in range(len(vis_preds)):
|
116 |
+
# n_predictions = len(vis_preds[j])
|
117 |
+
# bce_loss = 0.0
|
118 |
+
# for i in range(n_predictions):
|
119 |
+
# vis_e = vis_preds[j][i]
|
120 |
+
# vis_g = vis_gts[j]
|
121 |
+
# print('vis_e', vis_e.shape, 'vis_g', vis_g.shape)
|
122 |
+
# vis_loss = F.binary_cross_entropy(vis_e, vis_g)
|
123 |
+
# bce_loss += vis_loss
|
124 |
+
# bce_loss = bce_loss / n_predictions
|
125 |
+
# total_bce_loss += bce_loss
|
126 |
+
# return total_bce_loss / len(vis_preds)
|
127 |
+
|
128 |
+
|
129 |
+
def sequence_prob_loss(
|
130 |
+
tracks: torch.Tensor,
|
131 |
+
confidence: torch.Tensor,
|
132 |
+
target_points: torch.Tensor,
|
133 |
+
visibility: torch.Tensor,
|
134 |
+
expected_dist_thresh: float = 12.0,
|
135 |
+
use_logits=False,
|
136 |
+
):
|
137 |
+
"""Loss for classifying if a point is within pixel threshold of its target."""
|
138 |
+
# Points with an error larger than 12 pixels are likely to be useless; marking
|
139 |
+
# them as occluded will actually improve Jaccard metrics and give
|
140 |
+
# qualitatively better results.
|
141 |
+
total_logprob_loss = 0.0
|
142 |
+
for j in range(len(tracks)):
|
143 |
+
n_predictions = len(tracks[j])
|
144 |
+
logprob_loss = 0.0
|
145 |
+
for i in range(n_predictions):
|
146 |
+
err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=-1)
|
147 |
+
valid = (err <= expected_dist_thresh**2).float()
|
148 |
+
if use_logits:
|
149 |
+
loss = F.binary_cross_entropy_with_logits(confidence[j][i], valid, reduction="none")
|
150 |
+
else:
|
151 |
+
loss = F.binary_cross_entropy(confidence[j][i], valid, reduction="none")
|
152 |
+
loss *= visibility[j]
|
153 |
+
loss = torch.mean(loss, dim=[1, 2])
|
154 |
+
logprob_loss += loss
|
155 |
+
logprob_loss = logprob_loss / n_predictions
|
156 |
+
total_logprob_loss += logprob_loss
|
157 |
+
return total_logprob_loss / len(tracks)
|
158 |
+
|
159 |
+
def sequence_prob_loss_dense(
|
160 |
+
tracks: torch.Tensor,
|
161 |
+
confidence: torch.Tensor,
|
162 |
+
target_points: torch.Tensor,
|
163 |
+
visibility: torch.Tensor,
|
164 |
+
expected_dist_thresh: float = 12.0,
|
165 |
+
use_logits=False,
|
166 |
+
):
|
167 |
+
"""Loss for classifying if a point is within pixel threshold of its target."""
|
168 |
+
# Points with an error larger than 12 pixels are likely to be useless; marking
|
169 |
+
# them as occluded will actually improve Jaccard metrics and give
|
170 |
+
# qualitatively better results.
|
171 |
+
|
172 |
+
# all_confidence = [torch.stack(vp) for vp in confidence]
|
173 |
+
# all_confidence = torch.stack(all_confidence)
|
174 |
+
# utils.basic.print_stats('all_confidence', all_confidence)
|
175 |
+
|
176 |
+
total_logprob_loss = 0.0
|
177 |
+
for j in range(len(tracks)):
|
178 |
+
n_predictions = len(tracks[j])
|
179 |
+
logprob_loss = 0.0
|
180 |
+
for i in range(n_predictions):
|
181 |
+
# print('trajs_e', tracks[j][i].shape)
|
182 |
+
# print('trajs_g', target_points[j].shape)
|
183 |
+
err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=2)
|
184 |
+
positive = (err <= expected_dist_thresh**2).float()
|
185 |
+
# print('conf', confidence[j][i].shape, 'positive', positive.shape)
|
186 |
+
if use_logits:
|
187 |
+
loss = F.binary_cross_entropy_with_logits(confidence[j][i].squeeze(2), positive, reduction="none")
|
188 |
+
else:
|
189 |
+
loss = F.binary_cross_entropy(confidence[j][i].squeeze(2), positive, reduction="none")
|
190 |
+
loss *= visibility[j].squeeze(2) # B,S,H,W
|
191 |
+
loss = torch.mean(loss, dim=[1,2,3])
|
192 |
+
logprob_loss += loss
|
193 |
+
logprob_loss = logprob_loss / n_predictions
|
194 |
+
total_logprob_loss += logprob_loss
|
195 |
+
return total_logprob_loss / len(tracks)
|
196 |
+
|
197 |
+
|
198 |
+
def masked_mean(data, mask, dim):
|
199 |
+
if mask is None:
|
200 |
+
return data.mean(dim=dim, keepdim=True)
|
201 |
+
mask = mask.float()
|
202 |
+
mask_sum = torch.sum(mask, dim=dim, keepdim=True)
|
203 |
+
mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
|
204 |
+
mask_sum, min=1.0
|
205 |
+
)
|
206 |
+
return mask_mean
|
207 |
+
|
208 |
+
|
209 |
+
def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
|
210 |
+
if mask is None:
|
211 |
+
return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
|
212 |
+
mask = mask.float()
|
213 |
+
mask_sum = torch.sum(mask, dim=dim, keepdim=True)
|
214 |
+
mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
|
215 |
+
mask_sum, min=1.0
|
216 |
+
)
|
217 |
+
mask_var = torch.sum(
|
218 |
+
mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
|
219 |
+
) / torch.clamp(mask_sum, min=1.0)
|
220 |
+
return mask_mean.squeeze(dim), mask_var.squeeze(dim)
|
utils/misc.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, positions):
|
5 |
+
assert embed_dim % 2 == 0
|
6 |
+
omega = torch.arange(embed_dim // 2, dtype=torch.double)
|
7 |
+
omega /= embed_dim / 2.0
|
8 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
9 |
+
|
10 |
+
positions = positions.reshape(-1) # (M,)
|
11 |
+
out = torch.einsum("m,d->md", positions, omega) # (M, D/2), outer product
|
12 |
+
|
13 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
14 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
15 |
+
|
16 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
17 |
+
return emb[None].float()
|
18 |
+
|
19 |
+
|
20 |
+
class SimplePool():
|
21 |
+
def __init__(self, pool_size, version='pt', min_size=1):
|
22 |
+
self.pool_size = pool_size
|
23 |
+
self.version = version
|
24 |
+
self.items = []
|
25 |
+
self.min_size = min_size
|
26 |
+
|
27 |
+
if not (version=='pt' or version=='np'):
|
28 |
+
print('version = %s; please choose pt or np')
|
29 |
+
assert(False) # please choose pt or np
|
30 |
+
|
31 |
+
def __len__(self):
|
32 |
+
return len(self.items)
|
33 |
+
|
34 |
+
def mean(self, min_size=None):
|
35 |
+
if min_size is None:
|
36 |
+
pool_size_thresh = self.min_size
|
37 |
+
elif min_size=='half':
|
38 |
+
pool_size_thresh = self.pool_size/2
|
39 |
+
else:
|
40 |
+
pool_size_thresh = min_size
|
41 |
+
|
42 |
+
if self.version=='np':
|
43 |
+
if len(self.items) >= pool_size_thresh:
|
44 |
+
return np.sum(self.items)/float(len(self.items))
|
45 |
+
else:
|
46 |
+
return np.nan
|
47 |
+
if self.version=='pt':
|
48 |
+
if len(self.items) >= pool_size_thresh:
|
49 |
+
return torch.sum(self.items)/float(len(self.items))
|
50 |
+
else:
|
51 |
+
return torch.from_numpy(np.nan)
|
52 |
+
|
53 |
+
def sample(self, with_replacement=True):
|
54 |
+
idx = np.random.randint(len(self.items))
|
55 |
+
if with_replacement:
|
56 |
+
return self.items[idx]
|
57 |
+
else:
|
58 |
+
return self.items.pop(idx)
|
59 |
+
|
60 |
+
def fetch(self, num=None):
|
61 |
+
if self.version=='pt':
|
62 |
+
item_array = torch.stack(self.items)
|
63 |
+
elif self.version=='np':
|
64 |
+
item_array = np.stack(self.items)
|
65 |
+
if num is not None:
|
66 |
+
# there better be some items
|
67 |
+
assert(len(self.items) >= num)
|
68 |
+
|
69 |
+
# if there are not that many elements just return however many there are
|
70 |
+
if len(self.items) < num:
|
71 |
+
return item_array
|
72 |
+
else:
|
73 |
+
idxs = np.random.randint(len(self.items), size=num)
|
74 |
+
return item_array[idxs]
|
75 |
+
else:
|
76 |
+
return item_array
|
77 |
+
|
78 |
+
def is_full(self):
|
79 |
+
full = len(self.items)==self.pool_size
|
80 |
+
return full
|
81 |
+
|
82 |
+
def empty(self):
|
83 |
+
self.items = []
|
84 |
+
|
85 |
+
def have_min_size(self):
|
86 |
+
return len(self.items) >= self.min_size
|
87 |
+
|
88 |
+
|
89 |
+
def update(self, items):
|
90 |
+
for item in items:
|
91 |
+
if len(self.items) < self.pool_size:
|
92 |
+
# the pool is not full, so let's add this in
|
93 |
+
self.items.append(item)
|
94 |
+
else:
|
95 |
+
# the pool is full
|
96 |
+
# pop from the front
|
97 |
+
self.items.pop(0)
|
98 |
+
# add to the back
|
99 |
+
self.items.append(item)
|
100 |
+
return self.items
|
utils/py.py
ADDED
@@ -0,0 +1,755 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob, math
|
2 |
+
import numpy as np
|
3 |
+
# from scipy import misc
|
4 |
+
# from scipy import linalg
|
5 |
+
from PIL import Image
|
6 |
+
import io
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
EPS = 1e-6
|
9 |
+
|
10 |
+
|
11 |
+
XMIN = -64.0 # right (neg is left)
|
12 |
+
XMAX = 64.0 # right
|
13 |
+
YMIN = -64.0 # down (neg is up)
|
14 |
+
YMAX = 64.0 # down
|
15 |
+
ZMIN = -64.0 # forward
|
16 |
+
ZMAX = 64.0 # forward
|
17 |
+
|
18 |
+
def print_stats(name, tensor):
|
19 |
+
tensor = tensor.astype(np.float32)
|
20 |
+
print('%s min = %.2f, mean = %.2f, max = %.2f' % (name, np.min(tensor), np.mean(tensor), np.max(tensor)), tensor.shape)
|
21 |
+
|
22 |
+
def reduce_masked_mean(x, mask, axis=None, keepdims=False):
|
23 |
+
# x and mask are the same shape
|
24 |
+
# returns shape-1
|
25 |
+
# axis can be a list of axes
|
26 |
+
prod = x*mask
|
27 |
+
numer = np.sum(prod, axis=axis, keepdims=keepdims)
|
28 |
+
denom = EPS+np.sum(mask, axis=axis, keepdims=keepdims)
|
29 |
+
mean = numer/denom
|
30 |
+
return mean
|
31 |
+
|
32 |
+
def reduce_masked_sum(x, mask, axis=None, keepdims=False):
|
33 |
+
# x and mask are the same shape
|
34 |
+
# returns shape-1
|
35 |
+
# axis can be a list of axes
|
36 |
+
prod = x*mask
|
37 |
+
numer = np.sum(prod, axis=axis, keepdims=keepdims)
|
38 |
+
return numer
|
39 |
+
|
40 |
+
def reduce_masked_median(x, mask, keep_batch=False):
|
41 |
+
# x and mask are the same shape
|
42 |
+
# returns shape-1
|
43 |
+
# axis can be a list of axes
|
44 |
+
|
45 |
+
if not (x.shape == mask.shape):
|
46 |
+
print('reduce_masked_median: these shapes should match:', x.shape, mask.shape)
|
47 |
+
assert(False)
|
48 |
+
# assert(x.shape == mask.shape)
|
49 |
+
|
50 |
+
B = list(x.shape)[0]
|
51 |
+
|
52 |
+
if keep_batch:
|
53 |
+
x = np.reshape(x, [B, -1])
|
54 |
+
mask = np.reshape(mask, [B, -1])
|
55 |
+
meds = np.zeros([B], np.float32)
|
56 |
+
for b in list(range(B)):
|
57 |
+
xb = x[b]
|
58 |
+
mb = mask[b]
|
59 |
+
if np.sum(mb) > 0:
|
60 |
+
xb = xb[mb > 0]
|
61 |
+
meds[b] = np.median(xb)
|
62 |
+
else:
|
63 |
+
meds[b] = np.nan
|
64 |
+
return meds
|
65 |
+
else:
|
66 |
+
x = np.reshape(x, [-1])
|
67 |
+
mask = np.reshape(mask, [-1])
|
68 |
+
if np.sum(mask) > 0:
|
69 |
+
x = x[mask > 0]
|
70 |
+
med = np.median(x)
|
71 |
+
else:
|
72 |
+
med = np.nan
|
73 |
+
med = np.array([med], np.float32)
|
74 |
+
return med
|
75 |
+
|
76 |
+
def get_nFiles(path):
|
77 |
+
return len(glob.glob(path))
|
78 |
+
|
79 |
+
def get_file_list(path):
|
80 |
+
return glob.glob(path)
|
81 |
+
|
82 |
+
def rotm2eul(R):
|
83 |
+
# R is 3x3
|
84 |
+
sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0])
|
85 |
+
if sy > 1e-6: # singular
|
86 |
+
x = math.atan2(R[2,1] , R[2,2])
|
87 |
+
y = math.atan2(-R[2,0], sy)
|
88 |
+
z = math.atan2(R[1,0], R[0,0])
|
89 |
+
else:
|
90 |
+
x = math.atan2(-R[1,2], R[1,1])
|
91 |
+
y = math.atan2(-R[2,0], sy)
|
92 |
+
z = 0
|
93 |
+
return x, y, z
|
94 |
+
|
95 |
+
def rad2deg(rad):
|
96 |
+
return rad*180.0/np.pi
|
97 |
+
|
98 |
+
def deg2rad(deg):
|
99 |
+
return deg/180.0*np.pi
|
100 |
+
|
101 |
+
def eul2rotm(rx, ry, rz):
|
102 |
+
# copy of matlab, but order of inputs is different
|
103 |
+
# R = [ cy*cz sy*sx*cz-sz*cx sy*cx*cz+sz*sx
|
104 |
+
# cy*sz sy*sx*sz+cz*cx sy*cx*sz-cz*sx
|
105 |
+
# -sy cy*sx cy*cx]
|
106 |
+
sinz = np.sin(rz)
|
107 |
+
siny = np.sin(ry)
|
108 |
+
sinx = np.sin(rx)
|
109 |
+
cosz = np.cos(rz)
|
110 |
+
cosy = np.cos(ry)
|
111 |
+
cosx = np.cos(rx)
|
112 |
+
r11 = cosy*cosz
|
113 |
+
r12 = sinx*siny*cosz - cosx*sinz
|
114 |
+
r13 = cosx*siny*cosz + sinx*sinz
|
115 |
+
r21 = cosy*sinz
|
116 |
+
r22 = sinx*siny*sinz + cosx*cosz
|
117 |
+
r23 = cosx*siny*sinz - sinx*cosz
|
118 |
+
r31 = -siny
|
119 |
+
r32 = sinx*cosy
|
120 |
+
r33 = cosx*cosy
|
121 |
+
r1 = np.stack([r11,r12,r13],axis=-1)
|
122 |
+
r2 = np.stack([r21,r22,r23],axis=-1)
|
123 |
+
r3 = np.stack([r31,r32,r33],axis=-1)
|
124 |
+
r = np.stack([r1,r2,r3],axis=0)
|
125 |
+
return r
|
126 |
+
|
127 |
+
def wrap2pi(rad_angle):
|
128 |
+
# puts the angle into the range [-pi, pi]
|
129 |
+
return np.arctan2(np.sin(rad_angle), np.cos(rad_angle))
|
130 |
+
|
131 |
+
def rot2view(rx,ry,rz,x,y,z):
|
132 |
+
# takes rot angles and 3d position as input
|
133 |
+
# returns viewpoint angles as output
|
134 |
+
# (all in radians)
|
135 |
+
# it will perform strangely if z <= 0
|
136 |
+
az = wrap2pi(ry - (-np.arctan2(z, x) - 1.5*np.pi))
|
137 |
+
el = -wrap2pi(rx - (-np.arctan2(z, y) - 1.5*np.pi))
|
138 |
+
th = -rz
|
139 |
+
return az, el, th
|
140 |
+
|
141 |
+
def invAxB(a,b):
|
142 |
+
"""
|
143 |
+
Compute the relative 3d transformation between a and b.
|
144 |
+
|
145 |
+
Input:
|
146 |
+
a -- first pose (homogeneous 4x4 matrix)
|
147 |
+
b -- second pose (homogeneous 4x4 matrix)
|
148 |
+
|
149 |
+
Output:
|
150 |
+
Relative 3d transformation from a to b.
|
151 |
+
"""
|
152 |
+
return np.dot(np.linalg.inv(a),b)
|
153 |
+
|
154 |
+
def merge_rt(r, t):
|
155 |
+
# r is 3 x 3
|
156 |
+
# t is 3 or maybe 3 x 1
|
157 |
+
t = np.reshape(t, [3, 1])
|
158 |
+
rt = np.concatenate((r,t), axis=1)
|
159 |
+
# rt is 3 x 4
|
160 |
+
br = np.reshape(np.array([0,0,0,1], np.float32), [1, 4])
|
161 |
+
# br is 1 x 4
|
162 |
+
rt = np.concatenate((rt, br), axis=0)
|
163 |
+
# rt is 4 x 4
|
164 |
+
return rt
|
165 |
+
|
166 |
+
def split_rt(rt):
|
167 |
+
r = rt[:3,:3]
|
168 |
+
t = rt[:3,3]
|
169 |
+
r = np.reshape(r, [3, 3])
|
170 |
+
t = np.reshape(t, [3, 1])
|
171 |
+
return r, t
|
172 |
+
|
173 |
+
def split_intrinsics(K):
|
174 |
+
# K is 3 x 4 or 4 x 4
|
175 |
+
fx = K[0,0]
|
176 |
+
fy = K[1,1]
|
177 |
+
x0 = K[0,2]
|
178 |
+
y0 = K[1,2]
|
179 |
+
return fx, fy, x0, y0
|
180 |
+
|
181 |
+
def merge_intrinsics(fx, fy, x0, y0):
|
182 |
+
# inputs are shaped []
|
183 |
+
K = np.eye(4)
|
184 |
+
K[0,0] = fx
|
185 |
+
K[1,1] = fy
|
186 |
+
K[0,2] = x0
|
187 |
+
K[1,2] = y0
|
188 |
+
# K is shaped 4 x 4
|
189 |
+
return K
|
190 |
+
|
191 |
+
def scale_intrinsics(K, sx, sy):
|
192 |
+
fx, fy, x0, y0 = split_intrinsics(K)
|
193 |
+
fx *= sx
|
194 |
+
fy *= sy
|
195 |
+
x0 *= sx
|
196 |
+
y0 *= sy
|
197 |
+
return merge_intrinsics(fx, fy, x0, y0)
|
198 |
+
|
199 |
+
# def meshgrid(H, W):
|
200 |
+
# x = np.linspace(0, W-1, W)
|
201 |
+
# y = np.linspace(0, H-1, H)
|
202 |
+
# xv, yv = np.meshgrid(x, y)
|
203 |
+
# return xv, yv
|
204 |
+
|
205 |
+
def compute_distance(transform):
|
206 |
+
"""
|
207 |
+
Compute the distance of the translational component of a 4x4 homogeneous matrix.
|
208 |
+
"""
|
209 |
+
return numpy.linalg.norm(transform[0:3,3])
|
210 |
+
|
211 |
+
def radian_l1_dist(e, g):
|
212 |
+
# if our angles are in [0, 360] we can follow this stack overflow answer:
|
213 |
+
# https://gamedev.stackexchange.com/questions/4467/comparing-angles-and-working-out-the-difference
|
214 |
+
# wrap2pi brings the angles to [-180, 180]; adding pi puts them in [0, 360]
|
215 |
+
e = wrap2pi(e)+np.pi
|
216 |
+
g = wrap2pi(g)+np.pi
|
217 |
+
l = np.abs(np.pi - np.abs(np.abs(e-g) - np.pi))
|
218 |
+
return l
|
219 |
+
|
220 |
+
def apply_pix_T_cam(pix_T_cam, xyz):
|
221 |
+
fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
|
222 |
+
# xyz is shaped B x H*W x 3
|
223 |
+
# returns xy, shaped B x H*W x 2
|
224 |
+
N, C = xyz.shape
|
225 |
+
x, y, z = np.split(xyz, 3, axis=-1)
|
226 |
+
EPS = 1e-4
|
227 |
+
z = np.clip(z, EPS, None)
|
228 |
+
x = (x*fx)/(z)+x0
|
229 |
+
y = (y*fy)/(z)+y0
|
230 |
+
xy = np.concatenate([x, y], axis=-1)
|
231 |
+
return xy
|
232 |
+
|
233 |
+
def apply_4x4(RT, XYZ):
|
234 |
+
# RT is 4 x 4
|
235 |
+
# XYZ is N x 3
|
236 |
+
|
237 |
+
# put into homogeneous coords
|
238 |
+
X, Y, Z = np.split(XYZ, 3, axis=1)
|
239 |
+
ones = np.ones_like(X)
|
240 |
+
XYZ1 = np.concatenate([X, Y, Z, ones], axis=1)
|
241 |
+
# XYZ1 is N x 4
|
242 |
+
|
243 |
+
XYZ1_t = np.transpose(XYZ1)
|
244 |
+
# this is 4 x N
|
245 |
+
|
246 |
+
XYZ2_t = np.dot(RT, XYZ1_t)
|
247 |
+
# this is 4 x N
|
248 |
+
|
249 |
+
XYZ2 = np.transpose(XYZ2_t)
|
250 |
+
# this is N x 4
|
251 |
+
|
252 |
+
XYZ2 = XYZ2[:,:3]
|
253 |
+
# this is N x 3
|
254 |
+
|
255 |
+
return XYZ2
|
256 |
+
|
257 |
+
def Ref2Mem(xyz, Z, Y, X):
|
258 |
+
# xyz is N x 3, in ref coordinates
|
259 |
+
# transforms ref coordinates into mem coordinates
|
260 |
+
N, C = xyz.shape
|
261 |
+
assert(C==3)
|
262 |
+
mem_T_ref = get_mem_T_ref(Z, Y, X)
|
263 |
+
xyz = apply_4x4(mem_T_ref, xyz)
|
264 |
+
return xyz
|
265 |
+
|
266 |
+
# def Mem2Ref(xyz_mem, MH, MW, MD):
|
267 |
+
# # xyz is B x N x 3, in mem coordinates
|
268 |
+
# # transforms mem coordinates into ref coordinates
|
269 |
+
# B, N, C = xyz_mem.get_shape().as_list()
|
270 |
+
# ref_T_mem = get_ref_T_mem(B, MH, MW, MD)
|
271 |
+
# xyz_ref = utils_geom.apply_4x4(ref_T_mem, xyz_mem)
|
272 |
+
# return xyz_ref
|
273 |
+
|
274 |
+
def get_mem_T_ref(Z, Y, X):
|
275 |
+
# sometimes we want the mat itself
|
276 |
+
# note this is not a rigid transform
|
277 |
+
|
278 |
+
# for interpretability, let's construct this in two steps...
|
279 |
+
|
280 |
+
# translation
|
281 |
+
center_T_ref = np.eye(4, dtype=np.float32)
|
282 |
+
center_T_ref[0,3] = -XMIN
|
283 |
+
center_T_ref[1,3] = -YMIN
|
284 |
+
center_T_ref[2,3] = -ZMIN
|
285 |
+
|
286 |
+
VOX_SIZE_X = (XMAX-XMIN)/float(X)
|
287 |
+
VOX_SIZE_Y = (YMAX-YMIN)/float(Y)
|
288 |
+
VOX_SIZE_Z = (ZMAX-ZMIN)/float(Z)
|
289 |
+
|
290 |
+
# scaling
|
291 |
+
mem_T_center = np.eye(4, dtype=np.float32)
|
292 |
+
mem_T_center[0,0] = 1./VOX_SIZE_X
|
293 |
+
mem_T_center[1,1] = 1./VOX_SIZE_Y
|
294 |
+
mem_T_center[2,2] = 1./VOX_SIZE_Z
|
295 |
+
|
296 |
+
mem_T_ref = np.dot(mem_T_center, center_T_ref)
|
297 |
+
return mem_T_ref
|
298 |
+
|
299 |
+
def safe_inverse(a):
|
300 |
+
r, t = split_rt(a)
|
301 |
+
t = np.reshape(t, [3, 1])
|
302 |
+
r_transpose = r.T
|
303 |
+
inv = np.concatenate([r_transpose, -np.matmul(r_transpose, t)], 1)
|
304 |
+
bottom_row = a[3:4, :] # this is [0, 0, 0, 1]
|
305 |
+
inv = np.concatenate([inv, bottom_row], 0)
|
306 |
+
return inv
|
307 |
+
|
308 |
+
def get_ref_T_mem(Z, Y, X):
|
309 |
+
mem_T_ref = get_mem_T_ref(X, Y, X)
|
310 |
+
# note safe_inverse is inapplicable here,
|
311 |
+
# since the transform is nonrigid
|
312 |
+
ref_T_mem = np.linalg.inv(mem_T_ref)
|
313 |
+
return ref_T_mem
|
314 |
+
|
315 |
+
def voxelize_xyz(xyz_ref, Z, Y, X):
|
316 |
+
# xyz_ref is N x 3
|
317 |
+
xyz_mem = Ref2Mem(xyz_ref, Z, Y, X)
|
318 |
+
# this is N x 3
|
319 |
+
voxels = get_occupancy(xyz_mem, Z, Y, X)
|
320 |
+
voxels = np.reshape(voxels, [Z, Y, X, 1])
|
321 |
+
return voxels
|
322 |
+
|
323 |
+
def get_inbounds(xyz, Z, Y, X, already_mem=False):
|
324 |
+
# xyz is H*W x 3
|
325 |
+
|
326 |
+
if not already_mem:
|
327 |
+
xyz = Ref2Mem(xyz, Z, Y, X)
|
328 |
+
|
329 |
+
x_valid = np.logical_and(
|
330 |
+
np.greater_equal(xyz[:,0], -0.5),
|
331 |
+
np.less(xyz[:,0], float(X)-0.5))
|
332 |
+
y_valid = np.logical_and(
|
333 |
+
np.greater_equal(xyz[:,1], -0.5),
|
334 |
+
np.less(xyz[:,1], float(Y)-0.5))
|
335 |
+
z_valid = np.logical_and(
|
336 |
+
np.greater_equal(xyz[:,2], -0.5),
|
337 |
+
np.less(xyz[:,2], float(Z)-0.5))
|
338 |
+
inbounds = np.logical_and(np.logical_and(x_valid, y_valid), z_valid)
|
339 |
+
return inbounds
|
340 |
+
|
341 |
+
def sub2ind3d_zyx(depth, height, width, d, h, w):
|
342 |
+
# same as sub2ind3d, but inputs in zyx order
|
343 |
+
# when gathering/scattering with these inds, the tensor should be Z x Y x X
|
344 |
+
return d*height*width + h*width + w
|
345 |
+
|
346 |
+
def sub2ind3d_yxz(height, width, depth, h, w, d):
|
347 |
+
return h*width*depth + w*depth + d
|
348 |
+
|
349 |
+
# def ind2sub(height, width, ind):
|
350 |
+
# # int input
|
351 |
+
# y = int(ind / height)
|
352 |
+
# x = ind % height
|
353 |
+
# return y, x
|
354 |
+
|
355 |
+
def get_occupancy(xyz_mem, Z, Y, X):
|
356 |
+
# xyz_mem is N x 3
|
357 |
+
# we want to fill a voxel tensor with 1's at these inds
|
358 |
+
|
359 |
+
inbounds = get_inbounds(xyz_mem, Z, Y, X, already_mem=True)
|
360 |
+
inds = np.where(inbounds)
|
361 |
+
|
362 |
+
xyz_mem = np.reshape(xyz_mem[inds], [-1, 3])
|
363 |
+
# xyz_mem is N x 3
|
364 |
+
|
365 |
+
# this is more accurate than a cast/floor, but runs into issues when Y==0
|
366 |
+
xyz_mem = np.round(xyz_mem).astype(np.int32)
|
367 |
+
x = xyz_mem[:,0]
|
368 |
+
y = xyz_mem[:,1]
|
369 |
+
z = xyz_mem[:,2]
|
370 |
+
|
371 |
+
voxels = np.zeros([Z, Y, X], np.float32)
|
372 |
+
voxels[z, y, x] = 1.0
|
373 |
+
|
374 |
+
return voxels
|
375 |
+
|
376 |
+
def pixels2camera(x,y,z,fx,fy,x0,y0):
|
377 |
+
# x and y are locations in pixel coordinates, z is a depth image in meters
|
378 |
+
# their shapes are H x W
|
379 |
+
# fx, fy, x0, y0 are scalar camera intrinsics
|
380 |
+
# returns xyz, sized [B,H*W,3]
|
381 |
+
|
382 |
+
H, W = z.shape
|
383 |
+
|
384 |
+
fx = np.reshape(fx, [1,1])
|
385 |
+
fy = np.reshape(fy, [1,1])
|
386 |
+
x0 = np.reshape(x0, [1,1])
|
387 |
+
y0 = np.reshape(y0, [1,1])
|
388 |
+
|
389 |
+
# unproject
|
390 |
+
x = ((z+EPS)/fx)*(x-x0)
|
391 |
+
y = ((z+EPS)/fy)*(y-y0)
|
392 |
+
|
393 |
+
x = np.reshape(x, [-1])
|
394 |
+
y = np.reshape(y, [-1])
|
395 |
+
z = np.reshape(z, [-1])
|
396 |
+
xyz = np.stack([x,y,z], axis=1)
|
397 |
+
return xyz
|
398 |
+
|
399 |
+
def depth2pointcloud(z, pix_T_cam):
|
400 |
+
H = z.shape[0]
|
401 |
+
W = z.shape[1]
|
402 |
+
y, x = meshgrid2d(H, W)
|
403 |
+
z = np.reshape(z, [H, W])
|
404 |
+
|
405 |
+
fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
|
406 |
+
xyz = pixels2camera(x, y, z, fx, fy, x0, y0)
|
407 |
+
return xyz
|
408 |
+
|
409 |
+
def meshgrid2d(Y, X):
|
410 |
+
grid_y = np.linspace(0.0, Y-1, Y)
|
411 |
+
grid_y = np.reshape(grid_y, [Y, 1])
|
412 |
+
grid_y = np.tile(grid_y, [1, X])
|
413 |
+
|
414 |
+
grid_x = np.linspace(0.0, X-1, X)
|
415 |
+
grid_x = np.reshape(grid_x, [1, X])
|
416 |
+
grid_x = np.tile(grid_x, [Y, 1])
|
417 |
+
|
418 |
+
# outputs are Y x X
|
419 |
+
return grid_y, grid_x
|
420 |
+
|
421 |
+
def gridcloud3d(Y, X, Z):
|
422 |
+
x_ = np.linspace(0, X-1, X)
|
423 |
+
y_ = np.linspace(0, Y-1, Y)
|
424 |
+
z_ = np.linspace(0, Z-1, Z)
|
425 |
+
y, x, z = np.meshgrid(y_, x_, z_, indexing='ij')
|
426 |
+
x = np.reshape(x, [-1])
|
427 |
+
y = np.reshape(y, [-1])
|
428 |
+
z = np.reshape(z, [-1])
|
429 |
+
xyz = np.stack([x,y,z], axis=1).astype(np.float32)
|
430 |
+
return xyz
|
431 |
+
|
432 |
+
def gridcloud2d(Y, X):
|
433 |
+
x_ = np.linspace(0, X-1, X)
|
434 |
+
y_ = np.linspace(0, Y-1, Y)
|
435 |
+
y, x = np.meshgrid(y_, x_, indexing='ij')
|
436 |
+
x = np.reshape(x, [-1])
|
437 |
+
y = np.reshape(y, [-1])
|
438 |
+
xy = np.stack([x,y], axis=1).astype(np.float32)
|
439 |
+
return xy
|
440 |
+
|
441 |
+
def normalize(im):
|
442 |
+
im = im - np.min(im)
|
443 |
+
im = im / np.max(im)
|
444 |
+
return im
|
445 |
+
|
446 |
+
def wrap2pi(rad_angle):
|
447 |
+
# rad_angle can be any shape
|
448 |
+
# puts the angle into the range [-pi, pi]
|
449 |
+
return np.arctan2(np.sin(rad_angle), np.cos(rad_angle))
|
450 |
+
|
451 |
+
def convert_occ_to_height(occ):
|
452 |
+
Z, Y, X, C = occ.shape
|
453 |
+
assert(C==1)
|
454 |
+
|
455 |
+
height = np.linspace(float(Y), 1.0, Y)
|
456 |
+
height = np.reshape(height, [1, Y, 1, 1])
|
457 |
+
height = np.max(occ*height, axis=1)/float(Y)
|
458 |
+
height = np.reshape(height, [Z, X, C])
|
459 |
+
return height
|
460 |
+
|
461 |
+
def create_depth_image(xy, Z, H, W):
|
462 |
+
|
463 |
+
# turn the xy coordinates into image inds
|
464 |
+
xy = np.round(xy)
|
465 |
+
|
466 |
+
# lidar reports a sphere of measurements
|
467 |
+
# only use the inds that are within the image bounds
|
468 |
+
# also, only use forward-pointing depths (Z > 0)
|
469 |
+
valid = (xy[:,0] < W-1) & (xy[:,1] < H-1) & (xy[:,0] >= 0) & (xy[:,1] >= 0) & (Z[:] > 0)
|
470 |
+
|
471 |
+
# gather these up
|
472 |
+
xy = xy[valid]
|
473 |
+
Z = Z[valid]
|
474 |
+
|
475 |
+
inds = sub2ind(H,W,xy[:,1],xy[:,0])
|
476 |
+
depth = np.zeros((H*W), np.float32)
|
477 |
+
|
478 |
+
for (index, replacement) in zip(inds, Z):
|
479 |
+
depth[index] = replacement
|
480 |
+
depth[np.where(depth == 0.0)] = 70.0
|
481 |
+
depth = np.reshape(depth, [H, W])
|
482 |
+
|
483 |
+
return depth
|
484 |
+
|
485 |
+
def vis_depth(depth, maxdepth=80.0, log_vis=True):
|
486 |
+
depth[depth<=0.0] = maxdepth
|
487 |
+
if log_vis:
|
488 |
+
depth = np.log(depth)
|
489 |
+
depth = np.clip(depth, 0, np.log(maxdepth))
|
490 |
+
else:
|
491 |
+
depth = np.clip(depth, 0, maxdepth)
|
492 |
+
depth = (depth*255.0).astype(np.uint8)
|
493 |
+
return depth
|
494 |
+
|
495 |
+
def preprocess_color(x):
|
496 |
+
return x.astype(np.float32) * 1./255 - 0.5
|
497 |
+
|
498 |
+
def convert_box_to_ref_T_obj(boxes):
|
499 |
+
shape = boxes.shape
|
500 |
+
boxes = boxes.reshape(-1,9)
|
501 |
+
rots = [eul2rotm(rx,ry,rz)
|
502 |
+
for rx,ry,rz in boxes[:,6:]]
|
503 |
+
rots = np.stack(rots,axis=0)
|
504 |
+
trans = boxes[:,:3]
|
505 |
+
ref_T_objs = [merge_rt(rot,tran)
|
506 |
+
for rot,tran in zip(rots,trans)]
|
507 |
+
ref_T_objs = np.stack(ref_T_objs,axis=0)
|
508 |
+
ref_T_objs = ref_T_objs.reshape(shape[:-1]+(4,4))
|
509 |
+
ref_T_objs = ref_T_objs.astype(np.float32)
|
510 |
+
return ref_T_objs
|
511 |
+
|
512 |
+
def get_rot_from_delta(delta, yaw_only=False):
|
513 |
+
dx = delta[:,0]
|
514 |
+
dy = delta[:,1]
|
515 |
+
dz = delta[:,2]
|
516 |
+
|
517 |
+
bot_hyp = np.sqrt(dz**2 + dx**2)
|
518 |
+
# top_hyp = np.sqrt(bot_hyp**2 + dy**2)
|
519 |
+
|
520 |
+
pitch = -np.arctan2(dy, bot_hyp)
|
521 |
+
yaw = np.arctan2(dz, dx)
|
522 |
+
|
523 |
+
if yaw_only:
|
524 |
+
rot = [eul2rotm(0,y,0) for y in yaw]
|
525 |
+
else:
|
526 |
+
rot = [eul2rotm(0,y,p) for (p,y) in zip(pitch,yaw)]
|
527 |
+
|
528 |
+
rot = np.stack(rot)
|
529 |
+
# rot is B x 3 x 3
|
530 |
+
return rot
|
531 |
+
|
532 |
+
def im2col(im, psize):
|
533 |
+
n_channels = 1 if len(im.shape) == 2 else im.shape[0]
|
534 |
+
(n_channels, rows, cols) = (1,) * (3 - len(im.shape)) + im.shape
|
535 |
+
|
536 |
+
im_pad = np.zeros((n_channels,
|
537 |
+
int(math.ceil(1.0 * rows / psize) * psize),
|
538 |
+
int(math.ceil(1.0 * cols / psize) * psize)))
|
539 |
+
im_pad[:, 0:rows, 0:cols] = im
|
540 |
+
|
541 |
+
final = np.zeros((im_pad.shape[1], im_pad.shape[2], n_channels,
|
542 |
+
psize, psize))
|
543 |
+
for c in np.arange(n_channels):
|
544 |
+
for x in np.arange(psize):
|
545 |
+
for y in np.arange(psize):
|
546 |
+
im_shift = np.vstack(
|
547 |
+
(im_pad[c, x:], im_pad[c, :x]))
|
548 |
+
im_shift = np.column_stack(
|
549 |
+
(im_shift[:, y:], im_shift[:, :y]))
|
550 |
+
final[x::psize, y::psize, c] = np.swapaxes(
|
551 |
+
im_shift.reshape(int(im_pad.shape[1] / psize), psize,
|
552 |
+
int(im_pad.shape[2] / psize), psize), 1, 2)
|
553 |
+
|
554 |
+
return np.squeeze(final[0:rows - psize + 1, 0:cols - psize + 1])
|
555 |
+
|
556 |
+
def filter_discontinuities(depth, filter_size=9, thresh=10):
|
557 |
+
H, W = list(depth.shape)
|
558 |
+
|
559 |
+
# Ensure that filter sizes are okay
|
560 |
+
assert filter_size % 2 == 1, "Can only use odd filter sizes."
|
561 |
+
|
562 |
+
# Compute discontinuities
|
563 |
+
offset = int((filter_size - 1) / 2)
|
564 |
+
patches = 1.0 * im2col(depth, filter_size)
|
565 |
+
mids = patches[:, :, offset, offset]
|
566 |
+
mins = np.min(patches, axis=(2, 3))
|
567 |
+
maxes = np.max(patches, axis=(2, 3))
|
568 |
+
|
569 |
+
discont = np.maximum(np.abs(mins - mids),
|
570 |
+
np.abs(maxes - mids))
|
571 |
+
mark = discont > thresh
|
572 |
+
|
573 |
+
# Account for offsets
|
574 |
+
final_mark = np.zeros((H, W), dtype=np.uint16)
|
575 |
+
final_mark[offset:offset + mark.shape[0],
|
576 |
+
offset:offset + mark.shape[1]] = mark
|
577 |
+
|
578 |
+
return depth * (1 - final_mark)
|
579 |
+
|
580 |
+
def argmax2d(tensor):
|
581 |
+
Y, X = list(tensor.shape)
|
582 |
+
# flatten the Tensor along the height and width axes
|
583 |
+
flat_tensor = tensor.reshape(-1)
|
584 |
+
# argmax of the flat tensor
|
585 |
+
argmax = np.argmax(flat_tensor)
|
586 |
+
|
587 |
+
# convert the indices into 2d coordinates
|
588 |
+
argmax_y = argmax // X # row
|
589 |
+
argmax_x = argmax % X # col
|
590 |
+
|
591 |
+
return argmax_y, argmax_x
|
592 |
+
|
593 |
+
def plot_traj_3d(traj):
|
594 |
+
# traj is S x 3
|
595 |
+
|
596 |
+
# print('traj', traj.shape)
|
597 |
+
S, C = list(traj.shape)
|
598 |
+
assert(C==3)
|
599 |
+
|
600 |
+
fig = plt.figure()
|
601 |
+
ax = fig.add_subplot(111, projection='3d')
|
602 |
+
|
603 |
+
colors = [plt.cm.RdYlBu(i) for i in np.linspace(0,1,S)]
|
604 |
+
# print('colors', colors)
|
605 |
+
|
606 |
+
xs = traj[:,0]
|
607 |
+
ys = -traj[:,1]
|
608 |
+
zs = traj[:,2]
|
609 |
+
|
610 |
+
ax.scatter(xs, zs, ys, s=30, c=colors, marker='o', alpha=1.0, edgecolors=(0,0,0))#, color=color_map[n])
|
611 |
+
|
612 |
+
ax.set_xlabel('X')
|
613 |
+
ax.set_ylabel('Z')
|
614 |
+
ax.set_zlabel('Y')
|
615 |
+
|
616 |
+
ax.set_xlim(0,1)
|
617 |
+
ax.set_ylim(0,1) # this is really Z
|
618 |
+
ax.set_zlim(-1,0) # this is really Y
|
619 |
+
|
620 |
+
buf = io.BytesIO()
|
621 |
+
plt.savefig(buf, format='png')
|
622 |
+
buf.seek(0)
|
623 |
+
image = np.array(Image.open(buf)) # H x W x 4
|
624 |
+
image = image[:,:,:3]
|
625 |
+
|
626 |
+
plt.close()
|
627 |
+
return image
|
628 |
+
|
629 |
+
def camera2pixels(xyz, pix_T_cam):
|
630 |
+
# xyz is shaped N x 3
|
631 |
+
# returns xy, shaped N x 2
|
632 |
+
|
633 |
+
fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
|
634 |
+
x, y, z = xyz[:,0], xyz[:,1], xyz[:,2]
|
635 |
+
|
636 |
+
EPS = 1e-4
|
637 |
+
z = np.clip(z, EPS, None)
|
638 |
+
x = (x*fx)/z + x0
|
639 |
+
y = (y*fy)/z + y0
|
640 |
+
xy = np.stack([x, y], axis=-1)
|
641 |
+
return xy
|
642 |
+
|
643 |
+
def make_colorwheel():
|
644 |
+
"""
|
645 |
+
Generates a color wheel for optical flow visualization as presented in:
|
646 |
+
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
|
647 |
+
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
|
648 |
+
|
649 |
+
Code follows the original C++ source code of Daniel Scharstein.
|
650 |
+
Code follows the the Matlab source code of Deqing Sun.
|
651 |
+
|
652 |
+
Returns:
|
653 |
+
np.ndarray: Color wheel
|
654 |
+
"""
|
655 |
+
|
656 |
+
RY = 15
|
657 |
+
YG = 6
|
658 |
+
GC = 4
|
659 |
+
CB = 11
|
660 |
+
BM = 13
|
661 |
+
MR = 6
|
662 |
+
|
663 |
+
ncols = RY + YG + GC + CB + BM + MR
|
664 |
+
colorwheel = np.zeros((ncols, 3))
|
665 |
+
col = 0
|
666 |
+
|
667 |
+
# RY
|
668 |
+
colorwheel[0:RY, 0] = 255
|
669 |
+
colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
|
670 |
+
col = col+RY
|
671 |
+
# YG
|
672 |
+
colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
|
673 |
+
colorwheel[col:col+YG, 1] = 255
|
674 |
+
col = col+YG
|
675 |
+
# GC
|
676 |
+
colorwheel[col:col+GC, 1] = 255
|
677 |
+
colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
|
678 |
+
col = col+GC
|
679 |
+
# CB
|
680 |
+
colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
|
681 |
+
colorwheel[col:col+CB, 2] = 255
|
682 |
+
col = col+CB
|
683 |
+
# BM
|
684 |
+
colorwheel[col:col+BM, 2] = 255
|
685 |
+
colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
|
686 |
+
col = col+BM
|
687 |
+
# MR
|
688 |
+
colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
|
689 |
+
colorwheel[col:col+MR, 0] = 255
|
690 |
+
return colorwheel
|
691 |
+
|
692 |
+
def flow_uv_to_colors(u, v, convert_to_bgr=False):
|
693 |
+
"""
|
694 |
+
Applies the flow color wheel to (possibly clipped) flow components u and v.
|
695 |
+
|
696 |
+
According to the C++ source code of Daniel Scharstein
|
697 |
+
According to the Matlab source code of Deqing Sun
|
698 |
+
|
699 |
+
Args:
|
700 |
+
u (np.ndarray): Input horizontal flow of shape [H,W]
|
701 |
+
v (np.ndarray): Input vertical flow of shape [H,W]
|
702 |
+
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
703 |
+
|
704 |
+
Returns:
|
705 |
+
np.ndarray: Flow visualization image of shape [H,W,3]
|
706 |
+
"""
|
707 |
+
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
|
708 |
+
colorwheel = make_colorwheel() # shape [55x3]
|
709 |
+
ncols = colorwheel.shape[0]
|
710 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
711 |
+
a = np.arctan2(-v, -u)/np.pi
|
712 |
+
fk = (a+1) / 2*(ncols-1)
|
713 |
+
k0 = np.floor(fk).astype(np.int32)
|
714 |
+
k1 = k0 + 1
|
715 |
+
k1[k1 == ncols] = 0
|
716 |
+
f = fk - k0
|
717 |
+
for i in range(colorwheel.shape[1]):
|
718 |
+
tmp = colorwheel[:,i]
|
719 |
+
col0 = tmp[k0] / 255.0
|
720 |
+
col1 = tmp[k1] / 255.0
|
721 |
+
col = (1-f)*col0 + f*col1
|
722 |
+
idx = (rad <= 1)
|
723 |
+
col[idx] = 1 - rad[idx] * (1-col[idx])
|
724 |
+
col[~idx] = col[~idx] * 0.75 # out of range
|
725 |
+
# Note the 2-i => BGR instead of RGB
|
726 |
+
ch_idx = 2-i if convert_to_bgr else i
|
727 |
+
flow_image[:,:,ch_idx] = np.floor(255 * col)
|
728 |
+
return flow_image
|
729 |
+
|
730 |
+
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
|
731 |
+
"""
|
732 |
+
Expects a two dimensional flow image of shape.
|
733 |
+
|
734 |
+
Args:
|
735 |
+
flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
|
736 |
+
clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
|
737 |
+
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
738 |
+
|
739 |
+
Returns:
|
740 |
+
np.ndarray: Flow visualization image of shape [H,W,3]
|
741 |
+
"""
|
742 |
+
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
|
743 |
+
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
|
744 |
+
if clip_flow is not None:
|
745 |
+
flow_uv = np.clip(flow_uv, -clip_flow, clip_flow) / clip_flow
|
746 |
+
# flow_uv = np.clamp(flow, -clip, clip)/clip
|
747 |
+
|
748 |
+
u = flow_uv[:,:,0]
|
749 |
+
v = flow_uv[:,:,1]
|
750 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
751 |
+
rad_max = np.max(rad)
|
752 |
+
epsilon = 1e-5
|
753 |
+
u = u / (rad_max + epsilon)
|
754 |
+
v = v / (rad_max + epsilon)
|
755 |
+
return flow_uv_to_colors(u, v, convert_to_bgr)
|
utils/samp.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import utils.basic
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
|
6 |
+
r"""Sample a tensor using bilinear interpolation
|
7 |
+
|
8 |
+
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
|
9 |
+
coordinates :attr:`coords` using bilinear interpolation. It is the same
|
10 |
+
as `torch.nn.functional.grid_sample()` but with a different coordinate
|
11 |
+
convention.
|
12 |
+
|
13 |
+
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
|
14 |
+
:math:`B` is the batch size, :math:`C` is the number of channels,
|
15 |
+
:math:`H` is the height of the image, and :math:`W` is the width of the
|
16 |
+
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
|
17 |
+
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
|
18 |
+
|
19 |
+
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
|
20 |
+
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
|
21 |
+
that in this case the order of the components is slightly different
|
22 |
+
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
|
23 |
+
|
24 |
+
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
|
25 |
+
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
|
26 |
+
left-most image pixel :math:`W-1` to the center of the right-most
|
27 |
+
pixel.
|
28 |
+
|
29 |
+
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
|
30 |
+
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
|
31 |
+
the left-most pixel :math:`W` to the right edge of the right-most
|
32 |
+
pixel.
|
33 |
+
|
34 |
+
Similar conventions apply to the :math:`y` for the range
|
35 |
+
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
|
36 |
+
:math:`[0,T-1]` and :math:`[0,T]`.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
input (Tensor): batch of input images.
|
40 |
+
coords (Tensor): batch of coordinates.
|
41 |
+
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
|
42 |
+
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
Tensor: sampled points.
|
46 |
+
"""
|
47 |
+
|
48 |
+
sizes = input.shape[2:]
|
49 |
+
|
50 |
+
assert len(sizes) in [2, 3]
|
51 |
+
|
52 |
+
if len(sizes) == 3:
|
53 |
+
# t x y -> x y t to match dimensions T H W in grid_sample
|
54 |
+
coords = coords[..., [1, 2, 0]]
|
55 |
+
|
56 |
+
if align_corners:
|
57 |
+
coords = coords * torch.tensor(
|
58 |
+
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
|
59 |
+
)
|
60 |
+
else:
|
61 |
+
coords = coords * torch.tensor(
|
62 |
+
[2 / size for size in reversed(sizes)], device=coords.device
|
63 |
+
)
|
64 |
+
|
65 |
+
coords -= 1
|
66 |
+
|
67 |
+
return F.grid_sample(
|
68 |
+
input, coords, align_corners=align_corners, padding_mode=padding_mode
|
69 |
+
)
|
70 |
+
|
71 |
+
|
72 |
+
def sample_features4d(input, coords):
|
73 |
+
r"""Sample spatial features
|
74 |
+
|
75 |
+
`sample_features4d(input, coords)` samples the spatial features
|
76 |
+
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
|
77 |
+
|
78 |
+
The field is sampled at coordinates :attr:`coords` using bilinear
|
79 |
+
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
|
80 |
+
3)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
|
81 |
+
same convention as :func:`bilinear_sampler` with `align_corners=True`.
|
82 |
+
|
83 |
+
The output tensor has one feature per point, and has shape :math:`(B,
|
84 |
+
R, C)`.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
input (Tensor): spatial features.
|
88 |
+
coords (Tensor): points.
|
89 |
+
|
90 |
+
Returns:
|
91 |
+
Tensor: sampled features.
|
92 |
+
"""
|
93 |
+
|
94 |
+
B, _, _, _ = input.shape
|
95 |
+
|
96 |
+
# B R 2 -> B R 1 2
|
97 |
+
coords = coords.unsqueeze(2)
|
98 |
+
|
99 |
+
# B C R 1
|
100 |
+
feats = bilinear_sampler(input, coords)
|
101 |
+
|
102 |
+
return feats.permute(0, 2, 1, 3).view(
|
103 |
+
B, -1, feats.shape[1] * feats.shape[3]
|
104 |
+
) # B C R 1 -> B R C
|
105 |
+
|
106 |
+
|
107 |
+
def sample_features5d(input, coords):
|
108 |
+
r"""Sample spatio-temporal features
|
109 |
+
|
110 |
+
`sample_features5d(input, coords)` works in the same way as
|
111 |
+
:func:`sample_features4d` but for spatio-temporal features and points:
|
112 |
+
:attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
|
113 |
+
a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
|
114 |
+
x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
input (Tensor): spatio-temporal features.
|
118 |
+
coords (Tensor): spatio-temporal points.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
Tensor: sampled features.
|
122 |
+
"""
|
123 |
+
|
124 |
+
B, T, _, _, _ = input.shape
|
125 |
+
|
126 |
+
# B T C H W -> B C T H W
|
127 |
+
input = input.permute(0, 2, 1, 3, 4)
|
128 |
+
|
129 |
+
# B R1 R2 3 -> B R1 R2 1 3
|
130 |
+
coords = coords.unsqueeze(3)
|
131 |
+
|
132 |
+
# B C R1 R2 1
|
133 |
+
feats = bilinear_sampler(input, coords)
|
134 |
+
|
135 |
+
return feats.permute(0, 2, 3, 1, 4).view(
|
136 |
+
B, feats.shape[2], feats.shape[3], feats.shape[1]
|
137 |
+
) # B C R1 R2 1 -> B R1 R2 C
|
138 |
+
|
139 |
+
|
140 |
+
def bilinear_sample2d(im, x, y, return_inbounds=False):
|
141 |
+
# x and y are each B, N
|
142 |
+
# output is B, C, N
|
143 |
+
B, C, H, W = list(im.shape)
|
144 |
+
N = list(x.shape)[1]
|
145 |
+
|
146 |
+
x = x.float()
|
147 |
+
y = y.float()
|
148 |
+
H_f = torch.tensor(H, dtype=torch.float32)
|
149 |
+
W_f = torch.tensor(W, dtype=torch.float32)
|
150 |
+
|
151 |
+
# inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
|
152 |
+
|
153 |
+
max_y = (H_f - 1).int()
|
154 |
+
max_x = (W_f - 1).int()
|
155 |
+
|
156 |
+
x0 = torch.floor(x).int()
|
157 |
+
x1 = x0 + 1
|
158 |
+
y0 = torch.floor(y).int()
|
159 |
+
y1 = y0 + 1
|
160 |
+
|
161 |
+
x0_clip = torch.clamp(x0, 0, max_x)
|
162 |
+
x1_clip = torch.clamp(x1, 0, max_x)
|
163 |
+
y0_clip = torch.clamp(y0, 0, max_y)
|
164 |
+
y1_clip = torch.clamp(y1, 0, max_y)
|
165 |
+
dim2 = W
|
166 |
+
dim1 = W * H
|
167 |
+
|
168 |
+
base = torch.arange(0, B, dtype=torch.int64, device=x.device)*dim1
|
169 |
+
base = torch.reshape(base, [B, 1]).repeat([1, N])
|
170 |
+
|
171 |
+
base_y0 = base + y0_clip * dim2
|
172 |
+
base_y1 = base + y1_clip * dim2
|
173 |
+
|
174 |
+
idx_y0_x0 = base_y0 + x0_clip
|
175 |
+
idx_y0_x1 = base_y0 + x1_clip
|
176 |
+
idx_y1_x0 = base_y1 + x0_clip
|
177 |
+
idx_y1_x1 = base_y1 + x1_clip
|
178 |
+
|
179 |
+
# use the indices to lookup pixels in the flat image
|
180 |
+
# im is B x C x H x W
|
181 |
+
# move C out to last dim
|
182 |
+
im_flat = (im.permute(0, 2, 3, 1)).reshape(B*H*W, C)
|
183 |
+
i_y0_x0 = im_flat[idx_y0_x0.long()]
|
184 |
+
i_y0_x1 = im_flat[idx_y0_x1.long()]
|
185 |
+
i_y1_x0 = im_flat[idx_y1_x0.long()]
|
186 |
+
i_y1_x1 = im_flat[idx_y1_x1.long()]
|
187 |
+
|
188 |
+
# Finally calculate interpolated values.
|
189 |
+
x0_f = x0.float()
|
190 |
+
x1_f = x1.float()
|
191 |
+
y0_f = y0.float()
|
192 |
+
y1_f = y1.float()
|
193 |
+
|
194 |
+
w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
|
195 |
+
w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
|
196 |
+
w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
|
197 |
+
w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
|
198 |
+
|
199 |
+
output = w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + \
|
200 |
+
w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
|
201 |
+
# output is B*N x C
|
202 |
+
output = output.view(B, -1, C)
|
203 |
+
output = output.permute(0, 2, 1)
|
204 |
+
# output is B x C x N
|
205 |
+
|
206 |
+
if return_inbounds:
|
207 |
+
x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
|
208 |
+
y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
|
209 |
+
inbounds = (x_valid & y_valid).float()
|
210 |
+
inbounds = inbounds.reshape(B, N) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
|
211 |
+
return output, inbounds
|
212 |
+
|
213 |
+
return output # B, C, N
|
utils/saveload.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pathlib
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
|
5 |
+
def save(ckpt_dir, module, optimizer, scheduler, global_step, keep_latest=2, model_name='model'):
|
6 |
+
pathlib.Path(ckpt_dir).mkdir(exist_ok=True, parents=True)
|
7 |
+
prev_ckpts = list(pathlib.Path(ckpt_dir).glob('%s-*pth' % model_name))
|
8 |
+
prev_ckpts.sort(key=lambda p: p.stat().st_mtime,reverse=True)
|
9 |
+
if len(prev_ckpts) > keep_latest-1:
|
10 |
+
for f in prev_ckpts[keep_latest-1:]:
|
11 |
+
f.unlink()
|
12 |
+
save_path = '%s/%s-%09d.pth' % (ckpt_dir, model_name, global_step)
|
13 |
+
save_dict = {
|
14 |
+
"model": module.state_dict(),
|
15 |
+
"optimizer": optimizer.state_dict(),
|
16 |
+
"global_step": global_step,
|
17 |
+
}
|
18 |
+
if scheduler is not None:
|
19 |
+
save_dict['scheduler'] = scheduler.state_dict()
|
20 |
+
print(f"saving {save_path}")
|
21 |
+
torch.save(save_dict, save_path)
|
22 |
+
return False
|
23 |
+
|
24 |
+
def load(fabric, ckpt_path, model, optimizer=None, scheduler=None, model_ema=None, step=0, model_name='model', ignore_load=None, strict=True, verbose=True, weights_only=False):
|
25 |
+
if verbose:
|
26 |
+
print('reading ckpt from %s' % ckpt_path)
|
27 |
+
if not os.path.exists(ckpt_path):
|
28 |
+
print('...there is no full checkpoint in %s' % ckpt_path)
|
29 |
+
print('-- note this function no longer appends "saved_checkpoints/" before the ckpt_path --')
|
30 |
+
assert(False)
|
31 |
+
else:
|
32 |
+
if os.path.isfile(ckpt_path):
|
33 |
+
path = ckpt_path
|
34 |
+
print('...found checkpoint %s' % (path))
|
35 |
+
else:
|
36 |
+
prev_ckpts = list(pathlib.Path(ckpt_path).glob('%s-*pth' % model_name))
|
37 |
+
prev_ckpts.sort(key=lambda p: p.stat().st_mtime,reverse=True)
|
38 |
+
if len(prev_ckpts):
|
39 |
+
path = prev_ckpts[0]
|
40 |
+
# e.g., './checkpoints/2Ai4_5e-4_base18_1539/model-000050000.pth'
|
41 |
+
# OR ./whatever.pth
|
42 |
+
step = int(str(path).split('-')[-1].split('.')[0])
|
43 |
+
if verbose:
|
44 |
+
print('...found checkpoint %s; (parsed step %d from path)' % (path, step))
|
45 |
+
else:
|
46 |
+
print('...there is no full checkpoint here!')
|
47 |
+
return 0
|
48 |
+
if fabric is not None:
|
49 |
+
checkpoint = fabric.load(path)
|
50 |
+
else:
|
51 |
+
checkpoint = torch.load(path, weights_only=weights_only)
|
52 |
+
if optimizer is not None:
|
53 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
54 |
+
if scheduler is not None:
|
55 |
+
scheduler.load_state_dict(checkpoint['scheduler'])
|
56 |
+
assert ignore_load is None # not ready yet
|
57 |
+
if 'model' in checkpoint:
|
58 |
+
state_dict = checkpoint['model']
|
59 |
+
else:
|
60 |
+
state_dict = checkpoint
|
61 |
+
model.load_state_dict(state_dict, strict=strict)
|
62 |
+
return step
|
63 |
+
|
64 |
+
|
65 |
+
|