import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import utils.misc
from nets.blocks import (
    CNBlockConfig, ConvNeXt, conv1x1, RelUpdateBlock,
    InputPadder, CorrBlock, BasicEncoder,
)


class Net(nn.Module):
    def __init__(
            self,
            seqlen,
            use_attn=True,
            use_mixer=False,
            use_conv=False,
            use_convb=False,
            use_basicencoder=False,
            use_sinmotion=False,
            use_relmotion=False,
            use_sinrelmotion=False,
            use_feats8=False,
            no_time=False,
            no_space=False,
            no_split=False,
            no_ctx=False,
            full_split=False,
            corr_levels=5,
            corr_radius=4,
            num_blocks=3,
            dim=128,
            hdim=128,
            init_weights=True,
    ):
        super(Net, self).__init__()
        self.dim = dim
        self.hdim = hdim
        self.no_time = no_time
        self.no_space = no_space
        self.seqlen = seqlen
        self.corr_levels = corr_levels
        self.corr_radius = corr_radius
        self.corr_channel = self.corr_levels * (self.corr_radius * 2 + 1) ** 2
        self.num_blocks = num_blocks
        self.use_feats8 = use_feats8
        self.use_basicencoder = use_basicencoder
        self.use_sinmotion = use_sinmotion
        self.use_relmotion = use_relmotion
        self.use_sinrelmotion = use_sinrelmotion
        self.no_split = no_split
        self.no_ctx = no_ctx
        self.full_split = full_split

        if use_basicencoder:
            if self.full_split:
                self.fnet = BasicEncoder(input_dim=3, output_dim=self.dim, stride=8)
                self.cnet = BasicEncoder(input_dim=3, output_dim=self.dim, stride=8)
            else:
                if self.no_split:
                    self.fnet = BasicEncoder(input_dim=3, output_dim=self.dim, stride=8)
                else:
                    self.fnet = BasicEncoder(input_dim=3, output_dim=self.dim * 2, stride=8)
        else:
            block_setting = [
                CNBlockConfig(96, 192, 3, True),  # 4x
                CNBlockConfig(192, 384, 3, False),  # 8x
                CNBlockConfig(384, None, 9, False),  # 8x
            ]
            self.cnn = ConvNeXt(block_setting, stochastic_depth_prob=0.0, init_weights=init_weights)
            if self.no_split:
                self.dot_conv = conv1x1(384, dim)
            else:
                self.dot_conv = conv1x1(384, dim * 2)

        self.upsample_weight = nn.Sequential(
            # convex combination of 3x3 patches
            nn.Conv2d(dim, dim * 2, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim * 2, 64 * 9, 1, padding=0),
        )
        self.flow_head = nn.Sequential(
            nn.Conv2d(dim, 2 * dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(2 * dim, 2, kernel_size=3, padding=1),
        )
        self.visconf_head = nn.Sequential(
            nn.Conv2d(dim, 2 * dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(2 * dim, 2, kernel_size=3, padding=1),
        )

        if self.use_sinrelmotion:
            self.pdim = 84  # 4 + 4*2*10: 4 rel-flow channels plus sin/cos posenc (max_deg=10)
        elif self.use_relmotion:
            self.pdim = 4
        elif self.use_sinmotion:
            self.pdim = 42  # 2 + 2*2*10: 2 flow channels plus sin/cos posenc (max_deg=10)
        else:
            self.pdim = 2

        self.update_block = RelUpdateBlock(
            self.corr_channel, self.num_blocks, cdim=dim, hdim=hdim, pdim=self.pdim,
            use_attn=use_attn, use_mixer=use_mixer, use_conv=use_conv, use_convb=use_convb,
            use_layer_scale=True, no_time=no_time, no_space=no_space, no_ctx=no_ctx)

        time_line = torch.linspace(0, seqlen - 1, seqlen).reshape(1, seqlen, 1)
        self.register_buffer(
            "time_emb",
            utils.misc.get_1d_sincos_pos_embed_from_grid(self.dim, time_line[0]))  # 1,S,C

    def fetch_time_embed(self, t, dtype, is_training=False):
        S = self.time_emb.shape[1]
        if t == S:
            return self.time_emb.to(dtype)
        elif t == 1:
            if is_training:
                ind = np.random.choice(S)
                return self.time_emb[:, ind:ind + 1].to(dtype)
            else:
                return self.time_emb[:, 1:2].to(dtype)
        else:
            time_emb = self.time_emb.float()
            time_emb = F.interpolate(
                time_emb.permute(0, 2, 1), size=t, mode="linear").permute(0, 2, 1)
            return time_emb.to(dtype)

    def coords_grid(self, batch, ht, wd, device, dtype):
        coords = torch.meshgrid(
            torch.arange(ht, device=device, dtype=dtype),
            torch.arange(wd, device=device, dtype=dtype),
            indexing='ij')
        coords = torch.stack(coords[::-1], dim=0)  # (x,y) channel order
        return coords[None].repeat(batch, 1, 1, 1)
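    # --- usage sketch (illustrative, not in the original file) --------------
    # Assuming dim=128 and seqlen=16, the time-embedding buffer has shape
    # (1, 16, 128). Resampling it to an 8-frame window interpolates linearly
    # along the time axis:
    #
    #   emb = net.fetch_time_embed(8, torch.float32)                # (1, 8, 128)
    #   grid = net.coords_grid(2, 4, 6, device='cpu', dtype=torch.float32)
    #   assert grid.shape == (2, 2, 4, 6)                           # (batch, (x,y), ht, wd)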
    def initialize_flow(self, img):
        """Flow is represented as the difference between two coordinate grids:
        flow = coords2 - coords1."""
        N, C, H, W = img.shape
        coords1 = self.coords_grid(N, H // 8, W // 8, device=img.device, dtype=img.dtype)
        coords2 = self.coords_grid(N, H // 8, W // 8, device=img.device, dtype=img.dtype)
        return coords1, coords2

    def upsample_data(self, flow, mask):
        """Upsample [H/8, W/8, C] -> [H, W, C] using convex combination."""
        N, C, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(8 * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8 * H, 8 * W).to(flow.dtype)

    def get_T_padded_images(self, images, T, S, is_training, stride=None, pad=True):
        B, T, C, H, W = images.shape
        indices = None
        if T > 2:
            step = S // 2 if stride is None else stride
            indices = []
            start = 0
            while start + S < T:
                indices.append(start)
                start += step
            indices.append(start)
            Tpad = indices[-1] + S - T
            if pad:
                if is_training:
                    assert Tpad == 0
                else:
                    images = images.reshape(B, 1, T, C * H * W)
                    if Tpad > 0:
                        # repeat the final frame to fill out the last window
                        padding_tensor = images[:, :, -1:, :].expand(B, 1, Tpad, C * H * W)
                        images = torch.cat([images, padding_tensor], dim=2)
                    images = images.reshape(B, T + Tpad, C, H, W)
                    T = T + Tpad
        else:
            assert T == 2
        return images, T, indices

    def get_fmaps(self, images_, B, T, sw, is_training):
        _, _, H_pad, W_pad = images_.shape  # revised HW
        C, H8, W8 = self.dim * 2, H_pad // 8, W_pad // 8
        if self.no_split:
            C = self.dim

        fmaps_chunk_size = 32
        if (not is_training) and (T > fmaps_chunk_size):
            images = images_.reshape(B, T, 3, H_pad, W_pad)
            fmaps = []
            for t in range(0, T, fmaps_chunk_size):
                images_chunk = images[:, t:t + fmaps_chunk_size]
                images_chunk = images_chunk.cuda()
                if self.use_basicencoder:
                    if self.full_split:
                        fmaps_chunk1 = self.fnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
                        fmaps_chunk2 = self.cnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
                        fmaps_chunk = torch.cat([fmaps_chunk1, fmaps_chunk2], dim=1)
                    else:
                        fmaps_chunk = self.fnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
                else:
                    fmaps_chunk = self.cnn(images_chunk.reshape(-1, 3, H_pad, W_pad))
                if t == 0 and sw is not None and sw.save_this:
                    sw.summ_feat('1_model/fmap_raw', fmaps_chunk[0:1])
                fmaps_chunk = self.dot_conv(fmaps_chunk)  # B*T,C,H8,W8
                T_chunk = images_chunk.shape[1]
                fmaps.append(fmaps_chunk.reshape(B, T_chunk, C, H8, W8))
            fmaps_ = torch.cat(fmaps, dim=1).reshape(-1, C, H8, W8)
        else:
            if not is_training:
                # sometimes we need to move things to cuda here
                images_ = images_.cuda()
            if self.use_basicencoder:
                if self.full_split:
                    fmaps1_ = self.fnet(images_)
                    fmaps2_ = self.cnet(images_)
                    fmaps_ = torch.cat([fmaps1_, fmaps2_], dim=1)
                else:
                    fmaps_ = self.fnet(images_)
            else:
                fmaps_ = self.cnn(images_)
            if sw is not None and sw.save_this:
                sw.summ_feat('1_model/fmap_raw', fmaps_[0:1])
            fmaps_ = self.dot_conv(fmaps_)  # B*T,C,H8,W8
        return fmaps_
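    # --- worked example (illustrative, not in the original file) ------------
    # get_T_padded_images with S=16 and the default step S//2=8:
    #   T=48 -> indices=[0, 8, 16, 24, 32], Tpad = 32+16-48 = 0
    #   T=49 -> indices=[0, 8, 16, 24, 32, 40], Tpad = 40+16-49 = 7,
    #           so at eval time the final frame is repeated 7 times.
    # upsample_data takes flow of shape (N,2,H,W) and mask of shape
    # (N,576,H,W) (576 = 64*9 from upsample_weight), softmaxes the 9 weights
    # per 8x8 subpixel cell, and returns (N,2,8H,8W): RAFT-style convex
    # upsampling.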
    def forward(self, images, iters=4, sw=None, is_training=False, stride=None):
        B, T, C, H, W = images.shape
        S = self.seqlen
        device = images.device
        dtype = images.dtype

        # images are in [0,255]; normalize with ImageNet stats
        mean = torch.as_tensor([0.485, 0.456, 0.406], device=device).reshape(1, 1, 3, 1, 1).to(images.dtype)
        std = torch.as_tensor([0.229, 0.224, 0.225], device=device).reshape(1, 1, 3, 1, 1).to(images.dtype)
        images = images / 255.0
        images = (images - mean) / std

        T_bak = T
        pad = stride is None
        images, T, indices = self.get_T_padded_images(images, T, S, is_training, stride=stride, pad=pad)

        images = images.contiguous()
        images_ = images.reshape(B * T, 3, H, W)
        padder = InputPadder(images_.shape)
        images_ = padder.pad(images_)[0]

        _, _, H_pad, W_pad = images_.shape  # revised HW
        C, H8, W8 = self.dim * 2, H_pad // 8, W_pad // 8
        C2 = C // 2
        if self.no_split:
            C = self.dim
            C2 = C

        fmaps = self.get_fmaps(images_, B, T, sw, is_training).reshape(B, T, C, H8, W8)
        device = fmaps.device

        fmap_anchor = fmaps[:, 0]

        if T <= 2 or is_training:
            # note: collecting preds can get expensive on a long video
            all_flow_preds = []
            all_visconf_preds = []
        else:
            all_flow_preds = None
            all_visconf_preds = None

        if T > 2:  # multiframe tracking
            # we will store our final outputs in these tensors
            full_flows = torch.zeros((B, T, 2, H, W), dtype=dtype, device=device)
            full_visconfs = torch.zeros((B, T, 2, H, W), dtype=dtype, device=device)
            # 1/8 resolution
            full_flows8 = torch.zeros((B, T, 2, H_pad // 8, W_pad // 8), dtype=dtype, device=device)
            full_visconfs8 = torch.zeros((B, T, 2, H_pad // 8, W_pad // 8), dtype=dtype, device=device)
            if self.use_feats8:
                full_feats8 = torch.zeros((B, T, C2, H_pad // 8, W_pad // 8), dtype=dtype, device=device)
            visits = np.zeros(T)

            for ii, ind in enumerate(indices):
                ara = np.arange(ind, ind + S)
                fmaps2 = fmaps[:, ara]
                flows8 = full_flows8[:, ara].reshape(B * S, 2, H_pad // 8, W_pad // 8).detach()
                visconfs8 = full_visconfs8[:, ara].reshape(B * S, 2, H_pad // 8, W_pad // 8).detach()
                if self.use_feats8:
                    if ind == 0:
                        feats8 = None
                    else:
                        feats8 = full_feats8[:, ara].reshape(B * S, C2, H_pad // 8, W_pad // 8).detach()
                else:
                    feats8 = None

                flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
                    fmap_anchor, fmaps2, visconfs8, iters=iters, flowfeat=feats8,
                    flows8=flows8, is_training=is_training)

                unpad_flow_predictions = []
                unpad_visconf_predictions = []
                for i in range(len(flow_predictions)):
                    flow_predictions[i] = padder.unpad(flow_predictions[i])
                    unpad_flow_predictions.append(flow_predictions[i].reshape(B, S, 2, H, W))
                    visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
                    unpad_visconf_predictions.append(visconf_predictions[i].reshape(B, S, 2, H, W))

                full_flows[:, ara] = unpad_flow_predictions[-1].reshape(B, S, 2, H, W)
                full_flows8[:, ara] = flows8.reshape(B, S, 2, H_pad // 8, W_pad // 8)
                full_visconfs[:, ara] = unpad_visconf_predictions[-1].reshape(B, S, 2, H, W)
                full_visconfs8[:, ara] = visconfs8.reshape(B, S, 2, H_pad // 8, W_pad // 8)
                if self.use_feats8:
                    full_feats8[:, ara] = feats8.reshape(B, S, C2, H_pad // 8, W_pad // 8)
                visits[ara] += 1

                if is_training:
                    all_flow_preds.append(unpad_flow_predictions)
                    all_visconf_preds.append(unpad_visconf_predictions)
                else:
                    del unpad_flow_predictions
                    del unpad_visconf_predictions

                # for the next iter, replace empty data with nearest available preds
                invalid_idx = np.where(visits == 0)[0]
                valid_idx = np.where(visits > 0)[0]
                for idx in invalid_idx:
                    nearest = valid_idx[np.argmin(np.abs(valid_idx - idx))]
                    full_flows8[:, idx] = full_flows8[:, nearest]
                    full_visconfs8[:, idx] = full_visconfs8[:, nearest]
                    if self.use_feats8:
                        full_feats8[:, idx] = full_feats8[:, nearest]
        else:  # flow
            flows8 = torch.zeros((B, 2, H_pad // 8, W_pad // 8), dtype=dtype, device=device)
            visconfs8 = torch.zeros((B, 2, H_pad // 8, W_pad // 8), dtype=dtype, device=device)
            flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
                fmap_anchor, fmaps[:, 1:2], visconfs8, iters=iters, flowfeat=None,
                flows8=flows8, is_training=is_training)
            unpad_flow_predictions = []
            unpad_visconf_predictions = []
            for i in range(len(flow_predictions)):
                flow_predictions[i] = padder.unpad(flow_predictions[i])
                all_flow_preds.append(flow_predictions[i].reshape(B, 2, H, W))
                visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
                all_visconf_preds.append(visconf_predictions[i].reshape(B, 2, H, W))
            full_flows = all_flow_preds[-1].reshape(B, 2, H, W)
            full_visconfs = all_visconf_preds[-1].reshape(B, 2, H, W)

        if (not is_training) and (T > 2):
            full_flows = full_flows[:, :T_bak]
            full_visconfs = full_visconfs[:, :T_bak]

        return full_flows, full_visconfs, all_flow_preds, all_visconf_preds
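    # --- worked example (illustrative, not in the original file) ------------
    # Nearest-frame refill in forward(): suppose T=24 and S=16, so
    # indices=[0, 8]. After the first window, visits = [1]*16 + [0]*8.
    # For an unvisited idx=20:
    #   valid_idx = arange(16); nearest = valid_idx[argmin(|valid_idx - 20|)] = 15,
    # so frames 16..23 are warm-started from frame 15's 1/8-resolution state
    # before the second window [8..23] runs.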
    def forward_sliding(self, images, iters=4, sw=None, is_training=False, window_len=None, stride=None):
        B, T, C, H, W = images.shape
        S = self.seqlen if window_len is None else window_len
        device = images.device
        dtype = images.dtype
        stride = S // 2 if stride is None else stride

        T_bak = T
        images, T, indices = self.get_T_padded_images(images, T, S, is_training, stride)
        assert stride <= S // 2

        images = images.contiguous()
        images_ = images.reshape(B * T, 3, H, W)
        padder = InputPadder(images_.shape)
        images_ = padder.pad(images_)[0]

        _, _, H_pad, W_pad = images_.shape  # revised HW
        C, H8, W8 = self.dim * 2, H_pad // 8, W_pad // 8
        C2 = C // 2
        if self.no_split:
            C = self.dim
            C2 = C

        all_flow_preds = None
        all_visconf_preds = None

        if T <= 2:
            # note: collecting preds can get expensive on a long video
            all_flow_preds = []
            all_visconf_preds = []
            fmaps = self.get_fmaps(images_, B, T, sw, is_training).reshape(B, T, C, H8, W8)
            device = fmaps.device
            flows8 = torch.zeros((B, 2, H_pad // 8, W_pad // 8), dtype=dtype, device=device)
            visconfs8 = torch.zeros((B, 2, H_pad // 8, W_pad // 8), dtype=dtype, device=device)
            fmap_anchor = fmaps[:, 0]
            flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
                fmap_anchor, fmaps[:, 1:2], visconfs8, iters=iters, flowfeat=None,
                flows8=flows8, is_training=is_training)
            unpad_flow_predictions = []
            unpad_visconf_predictions = []
            for i in range(len(flow_predictions)):
                flow_predictions[i] = padder.unpad(flow_predictions[i])
                all_flow_preds.append(flow_predictions[i].reshape(B, 2, H, W))
                visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
                all_visconf_preds.append(visconf_predictions[i].reshape(B, 2, H, W))
            full_flows = all_flow_preds[-1].reshape(B, 2, H, W).detach().cpu()
            full_visconfs = all_visconf_preds[-1].reshape(B, 2, H, W).detach().cpu()
            return full_flows, full_visconfs, all_flow_preds, all_visconf_preds

        assert T > 2  # multiframe tracking

        if is_training:
            all_flow_preds = []
            all_visconf_preds = []

        # we will store our final outputs in these cpu tensors
        full_flows = torch.zeros((B, T, 2, H, W), dtype=dtype, device='cpu')
        full_visconfs = torch.zeros((B, T, 2, H, W), dtype=dtype, device='cpu')

        images_ = images_.reshape(B, T, 3, H_pad, W_pad)
        fmap_anchor = self.get_fmaps(
            images_[:, :1].reshape(-1, 3, H_pad, W_pad), B, 1, sw, is_training).reshape(B, C, H8, W8)
        device = fmap_anchor.device

        full_visited = torch.zeros((T,), dtype=torch.bool, device=device)

        for ii, ind in enumerate(indices):
            ara = np.arange(ind, ind + S)
            if ii == 0:
                flows8 = torch.zeros((B, S, 2, H_pad // 8, W_pad // 8), dtype=dtype, device=device)
                visconfs8 = torch.zeros((B, S, 2, H_pad // 8, W_pad // 8), dtype=dtype, device=device)
                fmaps2 = self.get_fmaps(
                    images_[:, ara].reshape(-1, 3, H_pad, W_pad), B, S, sw, is_training).reshape(B, S, C, H8, W8)
            else:
                # shift the window: keep the overlapping half, repeat the last
                # kept estimate to fill, and encode only the S//2 new frames
                flows8 = torch.cat([
                    flows8[:, stride:stride + S // 2],
                    flows8[:, stride + S // 2 - 1:stride + S // 2].repeat(1, S // 2, 1, 1, 1)], dim=1)
                visconfs8 = torch.cat([
                    visconfs8[:, stride:stride + S // 2],
                    visconfs8[:, stride + S // 2 - 1:stride + S // 2].repeat(1, S // 2, 1, 1, 1)], dim=1)
                fmaps2 = torch.cat([
                    fmaps2[:, stride:stride + S // 2],
                    self.get_fmaps(
                        images_[:, np.arange(ind + S // 2, ind + S)].reshape(-1, 3, H_pad, W_pad),
                        B, S // 2, sw, is_training).reshape(B, S // 2, C, H8, W8)], dim=1)

            flows8 = flows8.reshape(B * S, 2, H_pad // 8, W_pad // 8).detach()
            visconfs8 = visconfs8.reshape(B * S, 2, H_pad // 8, W_pad // 8).detach()

            flow_predictions, visconf_predictions, flows8, visconfs8, _ = self.forward_window(
                fmap_anchor, fmaps2, visconfs8, iters=iters, flowfeat=None,
                flows8=flows8, is_training=is_training)

            unpad_flow_predictions = []
            unpad_visconf_predictions = []
            for i in range(len(flow_predictions)):
                flow_predictions[i] = padder.unpad(flow_predictions[i])
                unpad_flow_predictions.append(flow_predictions[i].reshape(B, S, 2, H, W))
                visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
                unpad_visconf_predictions.append(visconf_predictions[i].reshape(B, S, 2, H, W))

            current_visiting = torch.zeros((T,), dtype=torch.bool, device=device)
            current_visiting[ara] = True
            to_fill = current_visiting & (~full_visited)
            to_fill_sum = to_fill.sum().item()
            full_flows[:, to_fill] = unpad_flow_predictions[-1].reshape(B, S, 2, H, W)[:, -to_fill_sum:].detach().cpu()
            full_visconfs[:, to_fill] = unpad_visconf_predictions[-1].reshape(B, S, 2, H, W)[:, -to_fill_sum:].detach().cpu()
            full_visited |= current_visiting

            if is_training:
                all_flow_preds.append(unpad_flow_predictions)
                all_visconf_preds.append(unpad_visconf_predictions)
            else:
                del unpad_flow_predictions
                del unpad_visconf_predictions

            flows8 = flows8.reshape(B, S, 2, H_pad // 8, W_pad // 8)
            visconfs8 = visconfs8.reshape(B, S, 2, H_pad // 8, W_pad // 8)

        if not is_training:
            full_flows = full_flows[:, :T_bak]
            full_visconfs = full_visconfs[:, :T_bak]

        return full_flows, full_visconfs, all_flow_preds, all_visconf_preds
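    # --- worked example (illustrative, not in the original file) ------------
    # Window shift in forward_sliding with S=16 and stride=8: after the
    # window at ind=0, the next window starts at ind=8. The slice
    # flows8[:, stride:stride+S//2] keeps the estimates for frames 8..15
    # (frames 0..7 of the new window), the last kept estimate is repeated
    # to seed frames 8..15, and get_fmaps runs only on the S//2 newly
    # revealed frames, so with stride=S//2 each frame is encoded exactly once.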
    def forward_window(self, fmap1_single, fmaps2, visconfs8, iters=None,
                       flowfeat=None, flows8=None, sw=None, is_training=False):
        B, S, C, H8, W8 = fmaps2.shape
        device = fmaps2.device
        dtype = fmaps2.dtype

        flow_predictions = []
        visconf_predictions = []

        fmap1 = fmap1_single.unsqueeze(1).repeat(1, S, 1, 1, 1)  # B,S,C,H,W
        fmap1 = fmap1.reshape(B * S, C, H8, W8).contiguous()
        fmap2 = fmaps2.reshape(B * S, C, H8, W8).contiguous()
        visconfs8 = visconfs8.reshape(B * S, 2, H8, W8).contiguous()

        corr_fn = CorrBlock(fmap1, fmap2, self.corr_levels, self.corr_radius)
        coords1 = self.coords_grid(B * S, H8, W8, device=fmap1.device, dtype=dtype)

        if self.no_split:
            flowfeat, ctxfeat = fmap1.clone(), fmap1.clone()
        else:
            if flowfeat is not None:
                _, ctxfeat = torch.split(fmap1, [self.dim, self.dim], dim=1)
            else:
                flowfeat, ctxfeat = torch.split(fmap1, [self.dim, self.dim], dim=1)

        # add pos emb to ctxfeat (and not flowfeat), since ctxfeat is untouched across iters
        time_emb = self.fetch_time_embed(S, ctxfeat.dtype, is_training).reshape(1, S, self.dim, 1, 1).repeat(B, 1, 1, 1, 1)
        ctxfeat = ctxfeat + time_emb.reshape(B * S, self.dim, 1, 1)
        if self.no_ctx:
            flowfeat = flowfeat + time_emb.reshape(B * S, self.dim, 1, 1)
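        # --- note (illustrative, not in the original file) -------------------
        # Motion encodings used in the loop below: with use_sinrelmotion, the
        # 4 relative-flow channels are lifted by utils.misc.posenc with
        # min_deg=0, max_deg=10 to pdim = 4 + 4*2*10 = 84 channels; with
        # use_sinmotion the 2 flow channels give pdim = 2 + 2*2*10 = 42
        # (assuming posenc concatenates its input with the sin/cos features).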
        for itr in range(iters):
            _, _, H8, W8 = flows8.shape
            flows8 = flows8.detach()
            coords2 = (coords1 + flows8).detach()  # B*S,2,H,W
            corr = corr_fn(coords2).to(dtype)

            if self.use_relmotion or self.use_sinrelmotion:
                coords_ = coords2.reshape(B, S, 2, H8 * W8).permute(0, 1, 3, 2)  # B,S,H8*W8,2
                rel_coords_forward = coords_[:, :-1] - coords_[:, 1:]
                rel_coords_backward = coords_[:, 1:] - coords_[:, :-1]
                rel_coords_forward = torch.nn.functional.pad(
                    rel_coords_forward, (0, 0, 0, 0, 0, 1))  # pad the 3rd-last dim (S) by (0,1)
                rel_coords_backward = torch.nn.functional.pad(
                    rel_coords_backward, (0, 0, 0, 0, 1, 0))  # pad the 3rd-last dim (S) by (1,0)
                rel_coords = torch.cat([rel_coords_forward, rel_coords_backward], dim=-1)  # B,S,H8*W8,4

                if self.use_sinrelmotion:
                    rel_pos_emb_input = utils.misc.posenc(
                        rel_coords,
                        min_deg=0,
                        max_deg=10,
                    )  # B,S,H*W,pdim
                    motion = rel_pos_emb_input.reshape(B * S, H8, W8, self.pdim).permute(0, 3, 1, 2).to(dtype)  # B*S,pdim,H8,W8
                else:
                    motion = rel_coords.reshape(B * S, H8, W8, 4).permute(0, 3, 1, 2).to(dtype)  # B*S,4,H8,W8
            else:
                if self.use_sinmotion:
                    pos_emb_input = utils.misc.posenc(
                        flows8.reshape(B, S, H8 * W8, 2),
                        min_deg=0,
                        max_deg=10,
                    )  # B,S,H*W,pdim
                    motion = pos_emb_input.reshape(B * S, H8, W8, self.pdim).permute(0, 3, 1, 2).to(dtype)  # B*S,pdim,H8,W8
                else:
                    motion = flows8

            flowfeat = self.update_block(flowfeat, ctxfeat, visconfs8, corr, motion, S)
            flow_update = self.flow_head(flowfeat)
            visconf_update = self.visconf_head(flowfeat)
            weight_update = .25 * self.upsample_weight(flowfeat)
            flows8 = flows8 + flow_update
            visconfs8 = visconfs8 + visconf_update
            flow_up = self.upsample_data(flows8, weight_update)
            visconf_up = self.upsample_data(visconfs8, weight_update)
            if not is_training:
                # clear mem
                flow_predictions = []
                visconf_predictions = []
            flow_predictions.append(flow_up)
            visconf_predictions.append(visconf_up)

        return flow_predictions, visconf_predictions, flows8, visconfs8, flowfeat
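

if __name__ == "__main__":
    # Usage sketch (illustrative; the argument values are assumptions, not
    # the original training configuration). Requires a CUDA device, since
    # the eval path moves tensors to GPU inside get_fmaps.
    net = Net(seqlen=8, use_basicencoder=True).cuda().eval()
    # a random 8-frame clip in [0,255], as expected by forward()
    video = torch.randint(0, 256, (1, 8, 3, 256, 256), dtype=torch.float32).cuda()
    with torch.no_grad():
        flows, visconfs, _, _ = net(video, iters=4, is_training=False)
    print(flows.shape, visconfs.shape)  # expect (1,8,2,256,256) for both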