import torch import torch.nn as nn import torch.nn.functional as F from torch import nn, Tensor from itertools import repeat import collections from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence from functools import partial import einops import math from torchvision.ops.misc import Conv2dNormActivation, Permute from torchvision.ops.stochastic_depth import StochasticDepth def _ntuple(n): def parse(x): if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): return tuple(x) return tuple(repeat(x, n)) return parse def exists(val): return val is not None def default(val, d): return val if exists(val) else d to_2tuple = _ntuple(2) class InputPadder: """ Pads images such that dimensions are divisible by a certain stride """ def __init__(self, dims, mode='sintel'): self.ht, self.wd = dims[-2:] pad_ht = (((self.ht // 64) + 1) * 64 - self.ht) % 64 pad_wd = (((self.wd // 64) + 1) * 64 - self.wd) % 64 if mode == 'sintel': self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] else: self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] def pad(self, *inputs): return [F.pad(x, self._pad, mode='replicate') for x in inputs] def unpad(self, x): ht, wd = x.shape[-2:] c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] return x[..., c[0]:c[1], c[2]:c[3]] def bilinear_sampler( input, coords, align_corners=True, padding_mode="border", normalize_coords=True): # func from mattie (oct9) if input.ndim not in [4, 5]: raise ValueError("input must be 4D or 5D.") if input.ndim == 4 and not coords.ndim == 4: raise ValueError("input is 4D, but coords is not 4D.") if input.ndim == 5 and not coords.ndim == 5: raise ValueError("input is 5D, but coords is not 5D.") if coords.ndim == 5: coords = coords[..., [1, 2, 0]] # t x y -> x y t to match what grid_sample() expects. if normalize_coords: if align_corners: # Normalize coordinates from [0, W/H - 1] to [-1, 1]. coords = ( coords * torch.tensor([2 / max(size - 1, 1) for size in reversed(input.shape[2:])], device=coords.device) - 1 ) else: # Normalize coordinates from [0, W/H] to [-1, 1]. coords = coords * torch.tensor([2 / size for size in reversed(input.shape[2:])], device=coords.device) - 1 return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) class CorrBlock: def __init__(self, fmap1, fmap2, corr_levels, corr_radius): self.num_levels = corr_levels self.radius = corr_radius self.corr_pyramid = [] # all pairs correlation for i in range(self.num_levels): corr = CorrBlock.corr(fmap1, fmap2, 1) batch, h1, w1, dim, h2, w2 = corr.shape corr = corr.reshape(batch*h1*w1, dim, h2, w2) fmap2 = F.interpolate(fmap2, scale_factor=0.5, mode='area') # print('corr', corr.shape) self.corr_pyramid.append(corr) def __call__(self, coords, dilation=None): r = self.radius coords = coords.permute(0, 2, 3, 1) batch, h1, w1, _ = coords.shape if dilation is None: dilation = torch.ones(batch, 1, h1, w1, device=coords.device) out_pyramid = [] for i in range(self.num_levels): corr = self.corr_pyramid[i] device = coords.device dx = torch.linspace(-r, r, 2*r+1, device=device) dy = torch.linspace(-r, r, 2*r+1, device=device) delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) delta_lvl = delta_lvl * dilation.view(batch * h1 * w1, 1, 1, 1) centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i coords_lvl = centroid_lvl + delta_lvl corr = bilinear_sampler(corr, coords_lvl) corr = corr.view(batch, h1, w1, -1) out_pyramid.append(corr) out = torch.cat(out_pyramid, dim=-1) out = out.permute(0, 3, 1, 2).contiguous().float() return out @staticmethod def corr(fmap1, fmap2, num_head): batch, dim, h1, w1 = fmap1.shape h2, w2 = fmap2.shape[2:] fmap1 = fmap1.view(batch, num_head, dim // num_head, h1*w1) fmap2 = fmap2.view(batch, num_head, dim // num_head, h2*w2) corr = fmap1.transpose(2, 3) @ fmap2 corr = corr.reshape(batch, num_head, h1, w1, h2, w2).permute(0, 2, 3, 1, 4, 5) return corr / torch.sqrt(torch.tensor(dim).float()) def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution without padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0) def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1) class LayerNorm2d(nn.LayerNorm): def forward(self, x: Tensor) -> Tensor: x = x.permute(0, 2, 3, 1) x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) x = x.permute(0, 3, 1, 2) return x class CNBlock1d(nn.Module): def __init__( self, dim, output_dim, layer_scale: float = 1e-6, stochastic_depth_prob: float = 0, norm_layer: Optional[Callable[..., nn.Module]] = None, dense=True, use_attn=True, use_mixer=False, use_conv=False, use_convb=False, use_layer_scale=True, ) -> None: super().__init__() self.dense = dense self.use_attn = use_attn self.use_mixer = use_mixer self.use_conv = use_conv self.use_layer_scale = use_layer_scale if use_attn: assert not use_mixer assert not use_conv assert not use_convb if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) if use_attn: num_heads = 8 self.block = AttnBlock( hidden_size=dim, num_heads=num_heads, mlp_ratio=4, attn_class=Attention, ) elif use_mixer: self.block = MLPMixerBlock( S=16, dim=dim, depth=1, expansion_factor=2, ) elif use_conv: self.block = nn.Sequential( nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True, padding_mode='zeros'), Permute([0, 2, 1]), norm_layer(dim), nn.Linear(in_features=dim, out_features=4 * dim, bias=True), nn.GELU(), nn.Linear(in_features=4 * dim, out_features=dim, bias=True), Permute([0, 2, 1]), ) elif use_convb: self.block = nn.Sequential( nn.Conv1d(dim, dim, kernel_size=3, padding=1, bias=True, padding_mode='zeros'), Permute([0, 2, 1]), norm_layer(dim), nn.Linear(in_features=dim, out_features=4 * dim, bias=True), nn.GELU(), nn.Linear(in_features=4 * dim, out_features=dim, bias=True), Permute([0, 2, 1]), ) else: assert(False) # choose attn, mixer, or conv please if self.use_layer_scale: self.layer_scale = nn.Parameter(torch.ones(dim, 1) * layer_scale) else: self.layer_scale = 1.0 self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") if output_dim != dim: self.final = nn.Conv1d(dim, output_dim, kernel_size=1, padding=0) else: self.final = nn.Identity() def forward(self, input, S=None): if self.dense: assert S is not None BS,C,H,W = input.shape B = BS//S input = einops.rearrange(input, '(b s) c h w -> (b h w) c s', b=B, s=S, c=C, h=H, w=W) if self.use_mixer or self.use_attn: # mixer/transformer blocks want B,S,C result = self.layer_scale * self.block(input.permute(0,2,1)).permute(0,2,1) else: result = self.layer_scale * self.block(input) result = self.stochastic_depth(result) result += input result = self.final(result) result = einops.rearrange(result, '(b h w) c s -> (b s) c h w', b=B, s=S, c=C, h=H, w=W) else: B,S,C = input.shape if S<7: return input input = einops.rearrange(input, 'b s c -> b c s', b=B, s=S, c=C) result = self.layer_scale * self.block(input) result = self.stochastic_depth(result) result += input result = self.final(result) result = einops.rearrange(result, 'b c s -> b s c', b=B, s=S, c=C) return result class CNBlock2d(nn.Module): def __init__( self, dim, output_dim, layer_scale: float = 1e-6, stochastic_depth_prob: float = 0, norm_layer: Optional[Callable[..., nn.Module]] = None, use_layer_scale=True, ) -> None: super().__init__() self.use_layer_scale = use_layer_scale if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.block = nn.Sequential( nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True, padding_mode='zeros'), Permute([0, 2, 3, 1]), norm_layer(dim), nn.Linear(in_features=dim, out_features=4 * dim, bias=True), nn.GELU(), nn.Linear(in_features=4 * dim, out_features=dim, bias=True), Permute([0, 3, 1, 2]), ) if self.use_layer_scale: self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale) else: self.layer_scale = 1.0 self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") if output_dim != dim: self.final = nn.Conv2d(dim, output_dim, kernel_size=1, padding=0) else: self.final = nn.Identity() def forward(self, input, S=None): result = self.layer_scale * self.block(input) result = self.stochastic_depth(result) result += input result = self.final(result) return result class CNBlockConfig: # Stores information listed at Section 3 of the ConvNeXt paper def __init__( self, input_channels: int, out_channels: Optional[int], num_layers: int, downsample: bool, ) -> None: self.input_channels = input_channels self.out_channels = out_channels self.num_layers = num_layers self.downsample = downsample def __repr__(self) -> str: s = self.__class__.__name__ + "(" s += "input_channels={input_channels}" s += ", out_channels={out_channels}" s += ", num_layers={num_layers}" s += ", downsample={downsample}" s += ")" return s.format(**self.__dict__) class ConvNeXt(nn.Module): def __init__( self, block_setting: List[CNBlockConfig], stochastic_depth_prob: float = 0.0, layer_scale: float = 1e-6, num_classes: int = 1000, block: Optional[Callable[..., nn.Module]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, init_weights=True): super().__init__() self.init_weights = init_weights if not block_setting: raise ValueError("The block_setting should not be empty") elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])): raise TypeError("The block_setting should be List[CNBlockConfig]") if block is None: block = CNBlock2d if norm_layer is None: norm_layer = partial(LayerNorm2d, eps=1e-6) layers: List[nn.Module] = [] # Stem firstconv_output_channels = block_setting[0].input_channels layers.append( Conv2dNormActivation( 3, firstconv_output_channels, kernel_size=4, stride=4, padding=0, norm_layer=norm_layer, activation_layer=None, bias=True, ) ) total_stage_blocks = sum(cnf.num_layers for cnf in block_setting) stage_block_id = 0 for cnf in block_setting: # Bottlenecks stage: List[nn.Module] = [] for _ in range(cnf.num_layers): # adjust stochastic depth probability based on the depth of the stage block sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) stage.append(block(cnf.input_channels, cnf.input_channels, layer_scale, sd_prob)) stage_block_id += 1 layers.append(nn.Sequential(*stage)) if cnf.out_channels is not None: if cnf.downsample: layers.append( nn.Sequential( norm_layer(cnf.input_channels), nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2), ) ) else: # we convert the 2x2 downsampling layer into a 3x3 with dilation2 and replicate padding. # replicate padding compensates for the fact that this kernel never saw zero-padding. layers.append( nn.Sequential( norm_layer(cnf.input_channels), nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=3, stride=1, padding=2, dilation=2, padding_mode='zeros'), ) ) self.features = nn.Sequential(*layers) # self.final_conv = conv1x1(block_setting[-1].input_channels, output_dim) for m in self.modules(): if isinstance(m, (nn.Conv2d, nn.Linear)): nn.init.trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.zeros_(m.bias) if self.init_weights: from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights pretrained_dict = convnext_tiny(weights=ConvNeXt_Tiny_Weights.DEFAULT).state_dict() # from torchvision.models import convnext_base, ConvNeXt_Base_Weights # pretrained_dict = convnext_base(weights=ConvNeXt_Base_Weights.DEFAULT).state_dict() model_dict = self.state_dict() pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} for k, v in pretrained_dict.items(): if k == 'features.4.1.weight': # this is the layer normally in charge of 2x2 downsampling # convert to 3x3 filter pretrained_dict[k] = F.interpolate(v, (3, 3), mode='bicubic', align_corners=True) * (4/9.0) model_dict.update(pretrained_dict) self.load_state_dict(model_dict, strict=False) def _forward_impl(self, x: Tensor) -> Tensor: x = self.features(x) # x = self.final_conv(x) return x def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) class Mlp(nn.Module): """MLP as used in Vision Transformer, MLP-Mixer and related networks""" def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, norm_layer=None, bias=True, drop=0.0, use_conv=False, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features bias = to_2tuple(bias) drop_probs = to_2tuple(drop) linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) self.norm = ( norm_layer(hidden_features) if norm_layer is not None else nn.Identity() ) self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop1(x) x = self.fc2(x) x = self.drop2(x) return x class Attention(nn.Module): def __init__( self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False ): super().__init__() inner_dim = dim_head * num_heads context_dim = default(context_dim, query_dim) self.scale = dim_head**-0.5 self.heads = num_heads self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias) self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias) self.to_out = nn.Linear(inner_dim, query_dim) def forward(self, x, context=None, attn_bias=None): B, N1, C = x.shape H = self.heads q = self.to_q(x) context = default(context, x) k, v = self.to_kv(context).chunk(2, dim=-1) q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> b h n d', h=self.heads), (q, k, v)) x = F.scaled_dot_product_attention(q, k, v) # scale default is already dim^-0.5 x = einops.rearrange(x, 'b h n d -> b n (h d)') return self.to_out(x) class CrossAttnBlock(nn.Module): def __init__( self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs ): super().__init__() self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.norm_context = nn.LayerNorm(hidden_size) self.cross_attn = Attention( hidden_size, context_dim=context_dim, num_heads=num_heads, qkv_bias=True, **block_kwargs ) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate="tanh") self.mlp = Mlp( in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0, ) def forward(self, x, context, mask=None): attn_bias = None if mask is not None: if mask.shape[1] == x.shape[1]: mask = mask[:, None, :, None].expand( -1, self.cross_attn.heads, -1, context.shape[1] ) else: mask = mask[:, None, None].expand( -1, self.cross_attn.heads, x.shape[1], -1 ) max_neg_value = -torch.finfo(x.dtype).max attn_bias = (~mask) * max_neg_value x = x + self.cross_attn( self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias ) x = x + self.mlp(self.norm2(x)) return x class AttnBlock(nn.Module): def __init__( self, hidden_size, num_heads, attn_class: Callable[..., nn.Module] = Attention, mlp_ratio=4.0, **block_kwargs ): super().__init__() self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, dim_head=hidden_size//num_heads) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate="tanh") self.mlp = Mlp( in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0, ) def forward(self, x, mask=None): attn_bias = mask if mask is not None: mask = ( (mask[:, None] * mask[:, :, None]) .unsqueeze(1) .expand(-1, self.attn.num_heads, -1, -1) ) max_neg_value = -torch.finfo(x.dtype).max attn_bias = (~mask) * max_neg_value x = x + self.attn(self.norm1(x), attn_bias=attn_bias) x = x + self.mlp(self.norm2(x)) return x class ResidualBlock(nn.Module): def __init__(self, in_planes, planes, norm_fn="group", stride=1): super(ResidualBlock, self).__init__() self.conv1 = nn.Conv2d( in_planes, planes, kernel_size=3, padding=1, stride=stride, padding_mode="zeros", ) self.conv2 = nn.Conv2d( planes, planes, kernel_size=3, padding=1, padding_mode="zeros" ) self.relu = nn.ReLU(inplace=True) num_groups = planes // 8 if norm_fn == "group": self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) if not stride == 1: self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) elif norm_fn == "batch": self.norm1 = nn.BatchNorm2d(planes) self.norm2 = nn.BatchNorm2d(planes) if not stride == 1: self.norm3 = nn.BatchNorm2d(planes) elif norm_fn == "instance": self.norm1 = nn.InstanceNorm2d(planes) self.norm2 = nn.InstanceNorm2d(planes) if not stride == 1: self.norm3 = nn.InstanceNorm2d(planes) elif norm_fn == "none": self.norm1 = nn.Sequential() self.norm2 = nn.Sequential() if not stride == 1: self.norm3 = nn.Sequential() if stride == 1: self.downsample = None else: self.downsample = nn.Sequential( nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 ) def forward(self, x): y = x y = self.relu(self.norm1(self.conv1(y))) y = self.relu(self.norm2(self.conv2(y))) if self.downsample is not None: x = self.downsample(x) return self.relu(x + y) class BasicEncoder(nn.Module): def __init__(self, input_dim=3, output_dim=128, stride=4): super(BasicEncoder, self).__init__() self.stride = stride self.norm_fn = "instance" self.in_planes = output_dim // 2 self.norm1 = nn.InstanceNorm2d(self.in_planes) self.norm2 = nn.InstanceNorm2d(output_dim * 2) self.conv1 = nn.Conv2d( input_dim, self.in_planes, kernel_size=7, stride=2, padding=3, padding_mode="zeros", ) self.relu1 = nn.ReLU(inplace=True) self.layer1 = self._make_layer(output_dim // 2, stride=1) self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2) self.layer3 = self._make_layer(output_dim, stride=2) self.layer4 = self._make_layer(output_dim, stride=2) self.conv2 = nn.Conv2d( output_dim * 3 + output_dim // 4, output_dim * 2, kernel_size=3, padding=1, padding_mode="zeros", ) self.relu2 = nn.ReLU(inplace=True) self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, (nn.InstanceNorm2d)): if m.weight is not None: nn.init.constant_(m.weight, 1) if m.bias is not None: nn.init.constant_(m.bias, 0) def _make_layer(self, dim, stride=1): layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) layers = (layer1, layer2) self.in_planes = dim return nn.Sequential(*layers) def forward(self, x): _, _, H, W = x.shape x = self.conv1(x) x = self.norm1(x) x = self.relu1(x) a = self.layer1(x) b = self.layer2(a) c = self.layer3(b) d = self.layer4(c) def _bilinear_intepolate(x): return F.interpolate( x, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True, ) a = _bilinear_intepolate(a) b = _bilinear_intepolate(b) c = _bilinear_intepolate(c) d = _bilinear_intepolate(d) x = self.conv2(torch.cat([a, b, c, d], dim=1)) x = self.norm2(x) x = self.relu2(x) x = self.conv3(x) return x class EfficientUpdateFormer(nn.Module): """ Transformer model that updates track estimates. """ def __init__( self, space_depth=6, time_depth=6, input_dim=320, hidden_size=384, num_heads=8, output_dim=130, mlp_ratio=4.0, num_virtual_tracks=64, add_space_attn=True, linear_layer_for_vis_conf=False, use_time_conv=False, use_time_mixer=False, ): super().__init__() self.out_channels = 2 self.num_heads = num_heads self.hidden_size = hidden_size self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) if linear_layer_for_vis_conf: self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True) self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True) else: self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) self.num_virtual_tracks = num_virtual_tracks self.virual_tracks = nn.Parameter( torch.randn(1, num_virtual_tracks, 1, hidden_size) ) self.add_space_attn = add_space_attn self.linear_layer_for_vis_conf = linear_layer_for_vis_conf if use_time_conv: self.time_blocks = nn.ModuleList( [ CNBlock1d(hidden_size, hidden_size, dense=False) for _ in range(time_depth) ] ) elif use_time_mixer: self.time_blocks = nn.ModuleList( [ MLPMixerBlock( S=16, dim=hidden_size, depth=1, ) for _ in range(time_depth) ] ) else: self.time_blocks = nn.ModuleList( [ AttnBlock( hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=Attention, ) for _ in range(time_depth) ] ) if add_space_attn: self.space_virtual_blocks = nn.ModuleList( [ AttnBlock( hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=Attention, ) for _ in range(space_depth) ] ) self.space_point2virtual_blocks = nn.ModuleList( [ CrossAttnBlock( hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio ) for _ in range(space_depth) ] ) self.space_virtual2point_blocks = nn.ModuleList( [ CrossAttnBlock( hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio ) for _ in range(space_depth) ] ) assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) self.initialize_weights() def initialize_weights(self): def _basic_init(module): if isinstance(module, nn.Linear): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001) if self.linear_layer_for_vis_conf: torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001) def _trunc_init(module): """ViT weight initialization, original timm impl (for reproducibility)""" if isinstance(module, nn.Linear): torch.nn.init.trunc_normal_(module.weight, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) self.apply(_basic_init) def forward(self, input_tensor, mask=None, add_space_attn=True): tokens = self.input_transform(input_tensor) B, _, T, _ = tokens.shape virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) tokens = torch.cat([tokens, virtual_tokens], dim=1) _, N, _, _ = tokens.shape j = 0 layers = [] for i in range(len(self.time_blocks)): time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C time_tokens = self.time_blocks[i](time_tokens) tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C if ( add_space_attn and hasattr(self, "space_virtual_blocks") and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0) ): space_tokens = ( tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) ) # B N T C -> (B T) N C point_tokens = space_tokens[:, : N - self.num_virtual_tracks] virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] virtual_tokens = self.space_virtual2point_blocks[j]( virtual_tokens, point_tokens, mask=mask ) virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) point_tokens = self.space_point2virtual_blocks[j]( point_tokens, virtual_tokens, mask=mask ) space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) tokens = space_tokens.view(B, T, N, -1).permute( 0, 2, 1, 3 ) # (B T) N C -> B N T C j += 1 tokens = tokens[:, : N - self.num_virtual_tracks] flow = self.flow_head(tokens) if self.linear_layer_for_vis_conf: vis_conf = self.vis_conf_head(tokens) flow = torch.cat([flow, vis_conf], dim=-1) return flow class MMPreNormResidual(nn.Module): def __init__(self, dim, fn): super().__init__() self.fn = fn self.norm = nn.LayerNorm(dim) def forward(self, x): return self.fn(self.norm(x)) + x def MMFeedForward(dim, expansion_factor=4, dropout=0., dense=nn.Linear): return nn.Sequential( dense(dim, dim * expansion_factor), nn.GELU(), nn.Dropout(dropout), dense(dim * expansion_factor, dim), nn.Dropout(dropout) ) def MLPMixer(S, input_dim, dim, output_dim, depth=6, expansion_factor=4, dropout=0., do_reduce=False): # input is coming in as B,S,C, as standard for mlp and transformer # chan_first treats S as the channel dim, and transforms it to a new S # chan_last treats C as the channel dim, and transforms it to a new C chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear if do_reduce: return nn.Sequential( nn.Linear(input_dim, dim), *[nn.Sequential( MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)), MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last)) ) for _ in range(depth)], nn.LayerNorm(dim), Reduce('b n c -> b c', 'mean'), nn.Linear(dim, output_dim) ) else: return nn.Sequential( nn.Linear(input_dim, dim), *[nn.Sequential( MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)), MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last)) ) for _ in range(depth)], ) def MLPMixerBlock(S, dim, depth=1, expansion_factor=4, dropout=0., do_reduce=False): # input is coming in as B,S,C, as standard for mlp and transformer # chan_first treats S as the channel dim, and transforms it to a new S # chan_last treats C as the channel dim, and transforms it to a new C chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear return nn.Sequential( *[nn.Sequential( MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)), MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last)) ) for _ in range(depth)], ) class MlpUpdateFormer(nn.Module): """ Transformer model that updates track estimates. """ def __init__( self, space_depth=6, time_depth=6, input_dim=320, hidden_size=384, num_heads=8, output_dim=130, mlp_ratio=4.0, num_virtual_tracks=64, add_space_attn=True, linear_layer_for_vis_conf=False, ): super().__init__() self.out_channels = 2 self.num_heads = num_heads self.hidden_size = hidden_size self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) if linear_layer_for_vis_conf: self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True) self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True) else: self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) self.num_virtual_tracks = num_virtual_tracks self.virual_tracks = nn.Parameter( torch.randn(1, num_virtual_tracks, 1, hidden_size) ) self.add_space_attn = add_space_attn self.linear_layer_for_vis_conf = linear_layer_for_vis_conf self.time_blocks = nn.ModuleList( [ MLPMixer( S=16, input_dim=hidden_size, dim=hidden_size, output_dim=hidden_size, depth=1, ) for _ in range(time_depth) ] ) if add_space_attn: self.space_virtual_blocks = nn.ModuleList( [ AttnBlock( hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=Attention, ) for _ in range(space_depth) ] ) self.space_point2virtual_blocks = nn.ModuleList( [ CrossAttnBlock( hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio ) for _ in range(space_depth) ] ) self.space_virtual2point_blocks = nn.ModuleList( [ CrossAttnBlock( hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio ) for _ in range(space_depth) ] ) assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) self.initialize_weights() def initialize_weights(self): def _basic_init(module): if isinstance(module, nn.Linear): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001) if self.linear_layer_for_vis_conf: torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001) def _trunc_init(module): """ViT weight initialization, original timm impl (for reproducibility)""" if isinstance(module, nn.Linear): torch.nn.init.trunc_normal_(module.weight, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) self.apply(_basic_init) def forward(self, input_tensor, mask=None, add_space_attn=True): tokens = self.input_transform(input_tensor) B, _, T, _ = tokens.shape virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) tokens = torch.cat([tokens, virtual_tokens], dim=1) _, N, _, _ = tokens.shape j = 0 layers = [] for i in range(len(self.time_blocks)): time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C time_tokens = self.time_blocks[i](time_tokens) tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C if ( add_space_attn and hasattr(self, "space_virtual_blocks") and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0) ): space_tokens = ( tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) ) # B N T C -> (B T) N C point_tokens = space_tokens[:, : N - self.num_virtual_tracks] virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] virtual_tokens = self.space_virtual2point_blocks[j]( virtual_tokens, point_tokens, mask=mask ) virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) point_tokens = self.space_point2virtual_blocks[j]( point_tokens, virtual_tokens, mask=mask ) space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) tokens = space_tokens.view(B, T, N, -1).permute( 0, 2, 1, 3 ) # (B T) N C -> B N T C j += 1 tokens = tokens[:, : N - self.num_virtual_tracks] flow = self.flow_head(tokens) if self.linear_layer_for_vis_conf: vis_conf = self.vis_conf_head(tokens) flow = torch.cat([flow, vis_conf], dim=-1) return flow class BasicMotionEncoder(nn.Module): def __init__(self, corr_channel, dim=128, pdim=2): super(BasicMotionEncoder, self).__init__() self.pdim = pdim self.convc1 = nn.Conv2d(corr_channel, dim*4, 1, padding=0) self.convc2 = nn.Conv2d(dim*4, dim+dim//2, 3, padding=1) if pdim==2 or pdim==4: self.convf1 = nn.Conv2d(pdim, dim*2, 5, padding=2) self.convf2 = nn.Conv2d(dim*2, dim//2, 3, padding=1) self.conv = nn.Conv2d(dim*2, dim-pdim, 3, padding=1) else: self.conv = nn.Conv2d(dim+dim//2+pdim, dim, 3, padding=1) def forward(self, flow, corr): cor = F.relu(self.convc1(corr)) cor = F.relu(self.convc2(cor)) if self.pdim==2 or self.pdim==4: flo = F.relu(self.convf1(flow)) flo = F.relu(self.convf2(flo)) cor_flo = torch.cat([cor, flo], dim=1) out = F.relu(self.conv(cor_flo)) return torch.cat([out, flow], dim=1) else: # the flow is already encoded to something nice cor_flo = torch.cat([cor, flow], dim=1) return F.relu(self.conv(cor_flo)) # return torch.cat([out, flow], dim=1) def conv133_encoder(input_dim, dim, expansion_factor=4): return nn.Sequential( nn.Conv2d(input_dim, dim*expansion_factor, kernel_size=1), nn.GELU(), nn.Conv2d(dim*expansion_factor, dim*expansion_factor, kernel_size=3, padding=1), nn.GELU(), nn.Conv2d(dim*expansion_factor, dim, kernel_size=3, padding=1), ) class BasicUpdateBlock(nn.Module): def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128): # flowfeat is hdim; ctxfeat is dim. typically hdim==cdim. super(BasicUpdateBlock, self).__init__() self.encoder = BasicMotionEncoder(corr_channel, dim=cdim) self.compressor = conv1x1(2*cdim+hdim, hdim) self.refine = [] for i in range(num_blocks): self.refine.append(CNBlock1d(hdim, hdim)) self.refine.append(CNBlock2d(hdim, hdim)) self.refine = nn.ModuleList(self.refine) def forward(self, flowfeat, ctxfeat, corr, flow, S, upsample=True): BS,C,H,W = flowfeat.shape B = BS//S # with torch.no_grad(): motion_features = self.encoder(flow, corr) flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features], dim=1)) for blk in self.refine: flowfeat = blk(flowfeat, S) return flowfeat class FullUpdateBlock(nn.Module): def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128, pdim=2, use_attn=False): # flowfeat is hdim; ctxfeat is dim. typically hdim==cdim. super(FullUpdateBlock, self).__init__() self.encoder = BasicMotionEncoder(corr_channel, dim=cdim, pdim=pdim) # note we have hdim==cdim # compressor chans: # dim for flowfeat # dim for ctxfeat # dim for motion_features # pdim for flow (if p 2, like if we give sincos(relflow)) # 2 for visconf if pdim==2: # hdim==cdim # dim for flowfeat # dim for ctxfeat # dim for motion_features # 2 for visconf self.compressor = conv1x1(2*cdim+hdim+2, hdim) else: # we concatenate the flow info again, to not lose it (e.g., from the relu) self.compressor = conv1x1(2*cdim+hdim+2+pdim, hdim) self.refine = [] for i in range(num_blocks): self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn)) self.refine.append(CNBlock2d(hdim, hdim)) self.refine = nn.ModuleList(self.refine) def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True): BS,C,H,W = flowfeat.shape B = BS//S motion_features = self.encoder(flow, corr) flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features, visconf], dim=1)) for blk in self.refine: flowfeat = blk(flowfeat, S) return flowfeat class MixerUpdateBlock(nn.Module): def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128): # flowfeat is hdim; ctxfeat is dim. typically hdim==cdim. super(MixerUpdateBlock, self).__init__() self.encoder = BasicMotionEncoder(corr_channel, dim=cdim) self.compressor = conv1x1(2*cdim+hdim, hdim) self.refine = [] for i in range(num_blocks): self.refine.append(CNBlock1d(hdim, hdim, use_mixer=True)) self.refine.append(CNBlock2d(hdim, hdim)) self.refine = nn.ModuleList(self.refine) def forward(self, flowfeat, ctxfeat, corr, flow, S, upsample=True): BS,C,H,W = flowfeat.shape B = BS//S # with torch.no_grad(): motion_features = self.encoder(flow, corr) flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features], dim=1)) for ii, blk in enumerate(self.refine): flowfeat = blk(flowfeat, S) return flowfeat class FacUpdateBlock(nn.Module): def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128, pdim=84, use_attn=False): super(FacUpdateBlock, self).__init__() self.corr_encoder = conv133_encoder(corr_channel, cdim) # note we have hdim==cdim # compressor chans: # dim for flowfeat # dim for ctxfeat # dim for corr # pdim for flow # 2 for visconf self.compressor = conv1x1(2*cdim+hdim+2+pdim, hdim) self.refine = [] for i in range(num_blocks): self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn)) self.refine.append(CNBlock2d(hdim, hdim)) self.refine = nn.ModuleList(self.refine) def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True): BS,C,H,W = flowfeat.shape B = BS//S corr = self.corr_encoder(corr) flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, corr, visconf, flow], dim=1)) for blk in self.refine: flowfeat = blk(flowfeat, S) return flowfeat class CleanUpdateBlock(nn.Module): def __init__(self, corr_channel, num_blocks, cdim=128, hdim=256, pdim=84, use_attn=False, use_layer_scale=True): super(CleanUpdateBlock, self).__init__() self.corr_encoder = conv133_encoder(corr_channel, cdim) # compressor chans: # cdim for flowfeat # cdim for ctxfeat # cdim for corrfeat # pdim for flow # 2 for visconf self.compressor = conv1x1(3*cdim+pdim+2, hdim) self.refine = [] for i in range(num_blocks): self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn, use_layer_scale=use_layer_scale)) self.refine.append(CNBlock2d(hdim, hdim, use_layer_scale=use_layer_scale)) self.refine = nn.ModuleList(self.refine) self.final_conv = conv1x1(hdim, cdim) def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True): BS,C,H,W = flowfeat.shape B = BS//S corrfeat = self.corr_encoder(corr) flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, corrfeat, flow, visconf], dim=1)) for blk in self.refine: flowfeat = blk(flowfeat, S) flowfeat = self.final_conv(flowfeat) return flowfeat class RelUpdateBlock(nn.Module): def __init__(self, corr_channel, num_blocks, cdim=128, hdim=128, pdim=4, use_attn=True, use_mixer=False, use_conv=False, use_convb=False, use_layer_scale=True, no_time=False, no_space=False, no_ctx=False): super(RelUpdateBlock, self).__init__() self.motion_encoder = BasicMotionEncoder(corr_channel, dim=hdim, pdim=pdim) # B,hdim,H,W self.no_ctx = no_ctx if no_ctx: self.compressor = conv1x1(cdim+hdim+2, hdim) else: self.compressor = conv1x1(2*cdim+hdim+2, hdim) self.refine = [] for i in range(num_blocks): if not no_time: self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn, use_mixer=use_mixer, use_conv=use_conv, use_convb=use_convb, use_layer_scale=use_layer_scale)) if not no_space: self.refine.append(CNBlock2d(hdim, hdim, use_layer_scale=use_layer_scale)) self.refine = nn.ModuleList(self.refine) self.final_conv = conv1x1(hdim, cdim) def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True): BS,C,H,W = flowfeat.shape B = BS//S motion_features = self.motion_encoder(flow, corr) if self.no_ctx: flowfeat = self.compressor(torch.cat([flowfeat, motion_features, visconf], dim=1)) else: flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features, visconf], dim=1)) for blk in self.refine: flowfeat = blk(flowfeat, S) flowfeat = self.final_conv(flowfeat) return flowfeat