# alltracker/nets/blocks.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from itertools import repeat
import collections
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
from functools import partial
import einops
from einops.layers.torch import Reduce  # used by MLPMixer when do_reduce=True
import math
from torchvision.ops.misc import Conv2dNormActivation, Permute
from torchvision.ops.stochastic_depth import StochasticDepth
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
to_2tuple = _ntuple(2)
class InputPadder:
""" Pads images such that dimensions are divisible by a certain stride """
def __init__(self, dims, mode='sintel'):
self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // 64) + 1) * 64 - self.ht) % 64
pad_wd = (((self.wd // 64) + 1) * 64 - self.wd) % 64
if mode == 'sintel':
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
else:
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
def pad(self, *inputs):
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
def unpad(self, x):
ht, wd = x.shape[-2:]
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
return x[..., c[0]:c[1], c[2]:c[3]]
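# Example usage of InputPadder (a minimal hedged sketch; the 436x1024 Sintel-like size is an
# illustrative assumption):
#   padder = InputPadder((1, 3, 436, 1024), mode='sintel')
#   im1, im2 = padder.pad(im1, im2)   # both padded so H and W are multiples of 64
#   flow = padder.unpad(flow)         # crop a prediction back to the original size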
def bilinear_sampler(
input, coords,
align_corners=True,
padding_mode="border",
normalize_coords=True):
# func from mattie (oct9)
if input.ndim not in [4, 5]:
raise ValueError("input must be 4D or 5D.")
if input.ndim == 4 and not coords.ndim == 4:
raise ValueError("input is 4D, but coords is not 4D.")
if input.ndim == 5 and not coords.ndim == 5:
raise ValueError("input is 5D, but coords is not 5D.")
if coords.ndim == 5:
coords = coords[..., [1, 2, 0]] # t x y -> x y t to match what grid_sample() expects.
if normalize_coords:
if align_corners:
# Normalize coordinates from [0, W/H - 1] to [-1, 1].
coords = (
coords
* torch.tensor([2 / max(size - 1, 1) for size in reversed(input.shape[2:])], device=coords.device)
- 1
)
else:
# Normalize coordinates from [0, W/H] to [-1, 1].
coords = coords * torch.tensor([2 / size for size in reversed(input.shape[2:])], device=coords.device) - 1
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
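# Example usage of bilinear_sampler (a hedged sketch; shapes are illustrative). With
# normalize_coords=True the coords are given in pixel units, ordered (x, y) in the last dim:
#   feat = torch.randn(2, 64, 32, 32)      # B,C,H,W
#   xy = torch.rand(2, 100, 1, 2) * 31     # B,H_out,W_out,2 pixel coordinates
#   sampled = bilinear_sampler(feat, xy)   # B,C,H_out,W_out = (2, 64, 100, 1)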
class CorrBlock:
def __init__(self, fmap1, fmap2, corr_levels, corr_radius):
self.num_levels = corr_levels
self.radius = corr_radius
self.corr_pyramid = []
# all pairs correlation
for i in range(self.num_levels):
corr = CorrBlock.corr(fmap1, fmap2, 1)
batch, h1, w1, dim, h2, w2 = corr.shape
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
fmap2 = F.interpolate(fmap2, scale_factor=0.5, mode='area')
# print('corr', corr.shape)
self.corr_pyramid.append(corr)
def __call__(self, coords, dilation=None):
r = self.radius
coords = coords.permute(0, 2, 3, 1)
batch, h1, w1, _ = coords.shape
if dilation is None:
dilation = torch.ones(batch, 1, h1, w1, device=coords.device)
out_pyramid = []
for i in range(self.num_levels):
corr = self.corr_pyramid[i]
device = coords.device
dx = torch.linspace(-r, r, 2*r+1, device=device)
dy = torch.linspace(-r, r, 2*r+1, device=device)
            delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1)
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
delta_lvl = delta_lvl * dilation.view(batch * h1 * w1, 1, 1, 1)
centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
coords_lvl = centroid_lvl + delta_lvl
corr = bilinear_sampler(corr, coords_lvl)
corr = corr.view(batch, h1, w1, -1)
out_pyramid.append(corr)
out = torch.cat(out_pyramid, dim=-1)
out = out.permute(0, 3, 1, 2).contiguous().float()
return out
@staticmethod
def corr(fmap1, fmap2, num_head):
batch, dim, h1, w1 = fmap1.shape
h2, w2 = fmap2.shape[2:]
fmap1 = fmap1.view(batch, num_head, dim // num_head, h1*w1)
fmap2 = fmap2.view(batch, num_head, dim // num_head, h2*w2)
corr = fmap1.transpose(2, 3) @ fmap2
corr = corr.reshape(batch, num_head, h1, w1, h2, w2).permute(0, 2, 3, 1, 4, 5)
return corr / torch.sqrt(torch.tensor(dim).float())
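# Example usage of CorrBlock (a hedged sketch; shapes are illustrative assumptions):
#   fmap1 = torch.randn(2, 128, 46, 62)                  # B,C,H,W features of frame 1
#   fmap2 = torch.randn(2, 128, 46, 62)                  # B,C,H,W features of frame 2
#   corr_fn = CorrBlock(fmap1, fmap2, corr_levels=4, corr_radius=3)
#   coords = torch.zeros(2, 2, 46, 62)                   # B,2,H,W lookup centers (x, y), in pixels
#   corr = corr_fn(coords)                               # B, corr_levels*(2*3+1)**2, H, W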
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution without padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0)
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
class LayerNorm2d(nn.LayerNorm):
def forward(self, x: Tensor) -> Tensor:
x = x.permute(0, 2, 3, 1)
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
x = x.permute(0, 3, 1, 2)
return x
class CNBlock1d(nn.Module):
def __init__(
self,
dim,
output_dim,
layer_scale: float = 1e-6,
stochastic_depth_prob: float = 0,
norm_layer: Optional[Callable[..., nn.Module]] = None,
dense=True,
use_attn=True,
use_mixer=False,
use_conv=False,
use_convb=False,
use_layer_scale=True,
) -> None:
super().__init__()
self.dense = dense
self.use_attn = use_attn
self.use_mixer = use_mixer
self.use_conv = use_conv
self.use_layer_scale = use_layer_scale
if use_attn:
assert not use_mixer
assert not use_conv
assert not use_convb
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
if use_attn:
num_heads = 8
self.block = AttnBlock(
hidden_size=dim,
num_heads=num_heads,
mlp_ratio=4,
attn_class=Attention,
)
elif use_mixer:
self.block = MLPMixerBlock(
S=16,
dim=dim,
depth=1,
expansion_factor=2,
)
elif use_conv:
self.block = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True, padding_mode='zeros'),
Permute([0, 2, 1]),
norm_layer(dim),
nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
nn.GELU(),
nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
Permute([0, 2, 1]),
)
elif use_convb:
self.block = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size=3, padding=1, bias=True, padding_mode='zeros'),
Permute([0, 2, 1]),
norm_layer(dim),
nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
nn.GELU(),
nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
Permute([0, 2, 1]),
)
else:
            raise ValueError("choose one of: use_attn, use_mixer, use_conv, use_convb")
if self.use_layer_scale:
self.layer_scale = nn.Parameter(torch.ones(dim, 1) * layer_scale)
else:
self.layer_scale = 1.0
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
if output_dim != dim:
self.final = nn.Conv1d(dim, output_dim, kernel_size=1, padding=0)
else:
self.final = nn.Identity()
def forward(self, input, S=None):
if self.dense:
assert S is not None
BS,C,H,W = input.shape
B = BS//S
input = einops.rearrange(input, '(b s) c h w -> (b h w) c s', b=B, s=S, c=C, h=H, w=W)
if self.use_mixer or self.use_attn:
# mixer/transformer blocks want B,S,C
result = self.layer_scale * self.block(input.permute(0,2,1)).permute(0,2,1)
else:
result = self.layer_scale * self.block(input)
result = self.stochastic_depth(result)
result += input
result = self.final(result)
result = einops.rearrange(result, '(b h w) c s -> (b s) c h w', b=B, s=S, c=C, h=H, w=W)
else:
B,S,C = input.shape
            if S < 7:
                # skip the temporal block for very short sequences; pass the input through unchanged
                return input
input = einops.rearrange(input, 'b s c -> b c s', b=B, s=S, c=C)
result = self.layer_scale * self.block(input)
result = self.stochastic_depth(result)
result += input
result = self.final(result)
result = einops.rearrange(result, 'b c s -> b s c', b=B, s=S, c=C)
return result
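# Example usage of CNBlock1d in dense mode (a hedged sketch; sizes are illustrative). In dense
# mode the block mixes information along the sequence dim S for every pixel independently:
#   blk = CNBlock1d(dim=128, output_dim=128, use_attn=True)
#   x = torch.randn(2 * 16, 128, 46, 62)   # (B*S, C, H, W) with B=2, S=16
#   y = blk(x, S=16)                        # same shape: (B*S, C, H, W)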
class CNBlock2d(nn.Module):
def __init__(
self,
dim,
output_dim,
layer_scale: float = 1e-6,
stochastic_depth_prob: float = 0,
norm_layer: Optional[Callable[..., nn.Module]] = None,
use_layer_scale=True,
) -> None:
super().__init__()
self.use_layer_scale = use_layer_scale
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.block = nn.Sequential(
nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True, padding_mode='zeros'),
Permute([0, 2, 3, 1]),
norm_layer(dim),
nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
nn.GELU(),
nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
Permute([0, 3, 1, 2]),
)
if self.use_layer_scale:
self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
else:
self.layer_scale = 1.0
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
if output_dim != dim:
self.final = nn.Conv2d(dim, output_dim, kernel_size=1, padding=0)
else:
self.final = nn.Identity()
def forward(self, input, S=None):
result = self.layer_scale * self.block(input)
result = self.stochastic_depth(result)
result += input
result = self.final(result)
return result
class CNBlockConfig:
    # Stores the per-stage configuration described in Section 3 of the ConvNeXt paper
def __init__(
self,
input_channels: int,
out_channels: Optional[int],
num_layers: int,
downsample: bool,
) -> None:
self.input_channels = input_channels
self.out_channels = out_channels
self.num_layers = num_layers
self.downsample = downsample
def __repr__(self) -> str:
s = self.__class__.__name__ + "("
s += "input_channels={input_channels}"
s += ", out_channels={out_channels}"
s += ", num_layers={num_layers}"
s += ", downsample={downsample}"
s += ")"
return s.format(**self.__dict__)
class ConvNeXt(nn.Module):
def __init__(
self,
block_setting: List[CNBlockConfig],
stochastic_depth_prob: float = 0.0,
layer_scale: float = 1e-6,
num_classes: int = 1000,
block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
init_weights=True):
super().__init__()
self.init_weights = init_weights
if not block_setting:
raise ValueError("The block_setting should not be empty")
elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
raise TypeError("The block_setting should be List[CNBlockConfig]")
if block is None:
block = CNBlock2d
if norm_layer is None:
norm_layer = partial(LayerNorm2d, eps=1e-6)
layers: List[nn.Module] = []
# Stem
firstconv_output_channels = block_setting[0].input_channels
layers.append(
Conv2dNormActivation(
3,
firstconv_output_channels,
kernel_size=4,
stride=4,
padding=0,
norm_layer=norm_layer,
activation_layer=None,
bias=True,
)
)
total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
stage_block_id = 0
for cnf in block_setting:
# Bottlenecks
stage: List[nn.Module] = []
for _ in range(cnf.num_layers):
# adjust stochastic depth probability based on the depth of the stage block
sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
stage.append(block(cnf.input_channels, cnf.input_channels, layer_scale, sd_prob))
stage_block_id += 1
layers.append(nn.Sequential(*stage))
if cnf.out_channels is not None:
if cnf.downsample:
layers.append(
nn.Sequential(
norm_layer(cnf.input_channels),
nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
)
)
else:
                    # instead of the 2x2 stride-2 downsampling conv, use a 3x3 conv with dilation 2
                    # and stride 1, so spatial resolution is preserved at this transition
                    # (the pretrained 2x2 kernel is resized to 3x3 in the weight-loading code below).
layers.append(
nn.Sequential(
norm_layer(cnf.input_channels),
nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=3, stride=1, padding=2, dilation=2, padding_mode='zeros'),
)
)
self.features = nn.Sequential(*layers)
# self.final_conv = conv1x1(block_setting[-1].input_channels, output_dim)
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
if self.init_weights:
from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights
pretrained_dict = convnext_tiny(weights=ConvNeXt_Tiny_Weights.DEFAULT).state_dict()
# from torchvision.models import convnext_base, ConvNeXt_Base_Weights
# pretrained_dict = convnext_base(weights=ConvNeXt_Base_Weights.DEFAULT).state_dict()
model_dict = self.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
for k, v in pretrained_dict.items():
if k == 'features.4.1.weight': # this is the layer normally in charge of 2x2 downsampling
# convert to 3x3 filter
pretrained_dict[k] = F.interpolate(v, (3, 3), mode='bicubic', align_corners=True) * (4/9.0)
model_dict.update(pretrained_dict)
self.load_state_dict(model_dict, strict=False)
def _forward_impl(self, x: Tensor) -> Tensor:
x = self.features(x)
# x = self.final_conv(x)
return x
def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)
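# Example construction (a hedged sketch; this particular block_setting is an illustrative
# ConvNeXt-Tiny-like assumption, not necessarily the configuration used elsewhere in this repo):
#   setting = [
#       CNBlockConfig(96, 192, 3, True),     # stage 1, then 2x2 downsample to 192 channels
#       CNBlockConfig(192, 384, 3, True),    # stage 2, then 2x2 downsample to 384 channels
#       CNBlockConfig(384, None, 9, False),  # stage 3, no transition afterwards
#   ]
#   net = ConvNeXt(setting, init_weights=False)
#   feat = net(torch.randn(1, 3, 224, 224))  # stride-4 stem plus two 2x downsamples -> 1/16 resolution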
class Mlp(nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.0,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = (
norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
)
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class Attention(nn.Module):
def __init__(
self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False
):
super().__init__()
inner_dim = dim_head * num_heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = num_heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
self.to_out = nn.Linear(inner_dim, query_dim)
def forward(self, x, context=None, attn_bias=None):
B, N1, C = x.shape
H = self.heads
q = self.to_q(x)
context = default(context, x)
k, v = self.to_kv(context).chunk(2, dim=-1)
q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> b h n d', h=self.heads), (q, k, v))
        # NOTE: attn_bias is accepted above but not forwarded to scaled_dot_product_attention,
        # so masks built by the calling blocks are not applied inside this attention.
        x = F.scaled_dot_product_attention(q, k, v)  # default scale is already dim_head**-0.5
x = einops.rearrange(x, 'b h n d -> b n (h d)')
return self.to_out(x)
class CrossAttnBlock(nn.Module):
def __init__(
self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs
):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm_context = nn.LayerNorm(hidden_size)
self.cross_attn = Attention(
hidden_size,
context_dim=context_dim,
num_heads=num_heads,
qkv_bias=True,
**block_kwargs
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=mlp_hidden_dim,
act_layer=approx_gelu,
drop=0,
)
def forward(self, x, context, mask=None):
attn_bias = None
if mask is not None:
if mask.shape[1] == x.shape[1]:
mask = mask[:, None, :, None].expand(
-1, self.cross_attn.heads, -1, context.shape[1]
)
else:
mask = mask[:, None, None].expand(
-1, self.cross_attn.heads, x.shape[1], -1
)
max_neg_value = -torch.finfo(x.dtype).max
attn_bias = (~mask) * max_neg_value
x = x + self.cross_attn(
self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
)
x = x + self.mlp(self.norm2(x))
return x
class AttnBlock(nn.Module):
def __init__(
self,
hidden_size,
num_heads,
attn_class: Callable[..., nn.Module] = Attention,
mlp_ratio=4.0,
**block_kwargs
):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, dim_head=hidden_size//num_heads)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=mlp_hidden_dim,
act_layer=approx_gelu,
drop=0,
)
def forward(self, x, mask=None):
attn_bias = mask
if mask is not None:
mask = (
(mask[:, None] * mask[:, :, None])
.unsqueeze(1)
                .expand(-1, self.attn.heads, -1, -1)
)
max_neg_value = -torch.finfo(x.dtype).max
attn_bias = (~mask) * max_neg_value
x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
x = x + self.mlp(self.norm2(x))
return x
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(
in_planes,
planes,
kernel_size=3,
padding=1,
stride=stride,
padding_mode="zeros",
)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == "none":
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not stride == 1:
self.norm3 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x + y)
class BasicEncoder(nn.Module):
def __init__(self, input_dim=3, output_dim=128, stride=4):
super(BasicEncoder, self).__init__()
self.stride = stride
self.norm_fn = "instance"
self.in_planes = output_dim // 2
self.norm1 = nn.InstanceNorm2d(self.in_planes)
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
self.conv1 = nn.Conv2d(
input_dim,
self.in_planes,
kernel_size=7,
stride=2,
padding=3,
padding_mode="zeros",
)
self.relu1 = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(output_dim // 2, stride=1)
self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
self.layer3 = self._make_layer(output_dim, stride=2)
self.layer4 = self._make_layer(output_dim, stride=2)
self.conv2 = nn.Conv2d(
output_dim * 3 + output_dim // 4,
output_dim * 2,
kernel_size=3,
padding=1,
padding_mode="zeros",
)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.InstanceNorm2d)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
_, _, H, W = x.shape
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
a = self.layer1(x)
b = self.layer2(a)
c = self.layer3(b)
d = self.layer4(c)
        def _bilinear_interpolate(x):
return F.interpolate(
x,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
        a = _bilinear_interpolate(a)
        b = _bilinear_interpolate(b)
        c = _bilinear_interpolate(c)
        d = _bilinear_interpolate(d)
x = self.conv2(torch.cat([a, b, c, d], dim=1))
x = self.norm2(x)
x = self.relu2(x)
x = self.conv3(x)
return x
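# Example usage of BasicEncoder (a hedged sketch; the input size is an illustrative assumption):
#   enc = BasicEncoder(input_dim=3, output_dim=128, stride=4)
#   img = torch.randn(2, 3, 368, 496)
#   feat = enc(img)   # (2, 128, 368//4, 496//4) = (2, 128, 92, 124)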
class EfficientUpdateFormer(nn.Module):
"""
Transformer model that updates track estimates.
"""
def __init__(
self,
space_depth=6,
time_depth=6,
input_dim=320,
hidden_size=384,
num_heads=8,
output_dim=130,
mlp_ratio=4.0,
num_virtual_tracks=64,
add_space_attn=True,
linear_layer_for_vis_conf=False,
use_time_conv=False,
use_time_mixer=False,
):
super().__init__()
self.out_channels = 2
self.num_heads = num_heads
self.hidden_size = hidden_size
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
if linear_layer_for_vis_conf:
self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
else:
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
self.num_virtual_tracks = num_virtual_tracks
self.virual_tracks = nn.Parameter(
torch.randn(1, num_virtual_tracks, 1, hidden_size)
)
self.add_space_attn = add_space_attn
self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
if use_time_conv:
self.time_blocks = nn.ModuleList(
[
CNBlock1d(hidden_size, hidden_size, dense=False)
for _ in range(time_depth)
]
)
elif use_time_mixer:
self.time_blocks = nn.ModuleList(
[
MLPMixerBlock(
S=16,
dim=hidden_size,
depth=1,
)
for _ in range(time_depth)
]
)
else:
self.time_blocks = nn.ModuleList(
[
AttnBlock(
hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
attn_class=Attention,
)
for _ in range(time_depth)
]
)
if add_space_attn:
self.space_virtual_blocks = nn.ModuleList(
[
AttnBlock(
hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
attn_class=Attention,
)
for _ in range(space_depth)
]
)
self.space_point2virtual_blocks = nn.ModuleList(
[
CrossAttnBlock(
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
)
for _ in range(space_depth)
]
)
self.space_virtual2point_blocks = nn.ModuleList(
[
CrossAttnBlock(
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
)
for _ in range(space_depth)
]
)
assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
self.initialize_weights()
def initialize_weights(self):
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
if self.linear_layer_for_vis_conf:
torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)
def _trunc_init(module):
"""ViT weight initialization, original timm impl (for reproducibility)"""
if isinstance(module, nn.Linear):
torch.nn.init.trunc_normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
self.apply(_basic_init)
def forward(self, input_tensor, mask=None, add_space_attn=True):
tokens = self.input_transform(input_tensor)
B, _, T, _ = tokens.shape
virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
tokens = torch.cat([tokens, virtual_tokens], dim=1)
_, N, _, _ = tokens.shape
j = 0
layers = []
for i in range(len(self.time_blocks)):
time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
time_tokens = self.time_blocks[i](time_tokens)
tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
if (
add_space_attn
and hasattr(self, "space_virtual_blocks")
and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
):
space_tokens = (
tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
) # B N T C -> (B T) N C
point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
virtual_tokens = self.space_virtual2point_blocks[j](
virtual_tokens, point_tokens, mask=mask
)
virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
point_tokens = self.space_point2virtual_blocks[j](
point_tokens, virtual_tokens, mask=mask
)
space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
tokens = space_tokens.view(B, T, N, -1).permute(
0, 2, 1, 3
) # (B T) N C -> B N T C
j += 1
tokens = tokens[:, : N - self.num_virtual_tracks]
flow = self.flow_head(tokens)
if self.linear_layer_for_vis_conf:
vis_conf = self.vis_conf_head(tokens)
flow = torch.cat([flow, vis_conf], dim=-1)
return flow
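# Example usage of EfficientUpdateFormer (a hedged sketch; sizes are illustrative). Input tokens
# are arranged as (B, N, T, C): N tracks, T timesteps, C = input_dim features per token:
#   former = EfficientUpdateFormer(input_dim=320, hidden_size=384, output_dim=130)
#   x = torch.randn(2, 64, 16, 320)
#   delta = former(x)   # (2, 64, 16, 130): an update for each track at each timestep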
class MMPreNormResidual(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
def forward(self, x):
return self.fn(self.norm(x)) + x
def MMFeedForward(dim, expansion_factor=4, dropout=0., dense=nn.Linear):
return nn.Sequential(
dense(dim, dim * expansion_factor),
nn.GELU(),
nn.Dropout(dropout),
dense(dim * expansion_factor, dim),
nn.Dropout(dropout)
)
def MLPMixer(S, input_dim, dim, output_dim, depth=6, expansion_factor=4, dropout=0., do_reduce=False):
# input is coming in as B,S,C, as standard for mlp and transformer
# chan_first treats S as the channel dim, and transforms it to a new S
# chan_last treats C as the channel dim, and transforms it to a new C
chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
if do_reduce:
return nn.Sequential(
nn.Linear(input_dim, dim),
*[nn.Sequential(
MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
) for _ in range(depth)],
nn.LayerNorm(dim),
Reduce('b n c -> b c', 'mean'),
nn.Linear(dim, output_dim)
)
else:
return nn.Sequential(
nn.Linear(input_dim, dim),
*[nn.Sequential(
MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
) for _ in range(depth)],
)
def MLPMixerBlock(S, dim, depth=1, expansion_factor=4, dropout=0., do_reduce=False):
# input is coming in as B,S,C, as standard for mlp and transformer
# chan_first treats S as the channel dim, and transforms it to a new S
# chan_last treats C as the channel dim, and transforms it to a new C
chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
return nn.Sequential(
*[nn.Sequential(
MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
) for _ in range(depth)],
)
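# Example usage of MLPMixerBlock (a hedged sketch). Note that S is baked into the token-mixing
# layers at construction time, so inputs must arrive with exactly that sequence length:
#   mixer = MLPMixerBlock(S=16, dim=384, depth=1)
#   x = torch.randn(2, 16, 384)   # B,S,C
#   y = mixer(x)                  # (2, 16, 384)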
class MlpUpdateFormer(nn.Module):
"""
Transformer model that updates track estimates.
"""
def __init__(
self,
space_depth=6,
time_depth=6,
input_dim=320,
hidden_size=384,
num_heads=8,
output_dim=130,
mlp_ratio=4.0,
num_virtual_tracks=64,
add_space_attn=True,
linear_layer_for_vis_conf=False,
):
super().__init__()
self.out_channels = 2
self.num_heads = num_heads
self.hidden_size = hidden_size
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
if linear_layer_for_vis_conf:
self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
else:
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
self.num_virtual_tracks = num_virtual_tracks
self.virual_tracks = nn.Parameter(
torch.randn(1, num_virtual_tracks, 1, hidden_size)
)
self.add_space_attn = add_space_attn
self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
self.time_blocks = nn.ModuleList(
[
MLPMixer(
S=16,
input_dim=hidden_size,
dim=hidden_size,
output_dim=hidden_size,
depth=1,
)
for _ in range(time_depth)
]
)
if add_space_attn:
self.space_virtual_blocks = nn.ModuleList(
[
AttnBlock(
hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
attn_class=Attention,
)
for _ in range(space_depth)
]
)
self.space_point2virtual_blocks = nn.ModuleList(
[
CrossAttnBlock(
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
)
for _ in range(space_depth)
]
)
self.space_virtual2point_blocks = nn.ModuleList(
[
CrossAttnBlock(
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
)
for _ in range(space_depth)
]
)
assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
self.initialize_weights()
def initialize_weights(self):
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
if self.linear_layer_for_vis_conf:
torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)
def _trunc_init(module):
"""ViT weight initialization, original timm impl (for reproducibility)"""
if isinstance(module, nn.Linear):
torch.nn.init.trunc_normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
self.apply(_basic_init)
def forward(self, input_tensor, mask=None, add_space_attn=True):
tokens = self.input_transform(input_tensor)
B, _, T, _ = tokens.shape
virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
tokens = torch.cat([tokens, virtual_tokens], dim=1)
_, N, _, _ = tokens.shape
j = 0
layers = []
for i in range(len(self.time_blocks)):
time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
time_tokens = self.time_blocks[i](time_tokens)
tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
if (
add_space_attn
and hasattr(self, "space_virtual_blocks")
and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
):
space_tokens = (
tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
) # B N T C -> (B T) N C
point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
virtual_tokens = self.space_virtual2point_blocks[j](
virtual_tokens, point_tokens, mask=mask
)
virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
point_tokens = self.space_point2virtual_blocks[j](
point_tokens, virtual_tokens, mask=mask
)
space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
tokens = space_tokens.view(B, T, N, -1).permute(
0, 2, 1, 3
) # (B T) N C -> B N T C
j += 1
tokens = tokens[:, : N - self.num_virtual_tracks]
flow = self.flow_head(tokens)
if self.linear_layer_for_vis_conf:
vis_conf = self.vis_conf_head(tokens)
flow = torch.cat([flow, vis_conf], dim=-1)
return flow
class BasicMotionEncoder(nn.Module):
def __init__(self, corr_channel, dim=128, pdim=2):
super(BasicMotionEncoder, self).__init__()
self.pdim = pdim
self.convc1 = nn.Conv2d(corr_channel, dim*4, 1, padding=0)
self.convc2 = nn.Conv2d(dim*4, dim+dim//2, 3, padding=1)
if pdim==2 or pdim==4:
self.convf1 = nn.Conv2d(pdim, dim*2, 5, padding=2)
self.convf2 = nn.Conv2d(dim*2, dim//2, 3, padding=1)
self.conv = nn.Conv2d(dim*2, dim-pdim, 3, padding=1)
else:
self.conv = nn.Conv2d(dim+dim//2+pdim, dim, 3, padding=1)
def forward(self, flow, corr):
cor = F.relu(self.convc1(corr))
cor = F.relu(self.convc2(cor))
if self.pdim==2 or self.pdim==4:
flo = F.relu(self.convf1(flow))
flo = F.relu(self.convf2(flo))
cor_flo = torch.cat([cor, flo], dim=1)
out = F.relu(self.conv(cor_flo))
return torch.cat([out, flow], dim=1)
else:
# the flow is already encoded to something nice
cor_flo = torch.cat([cor, flow], dim=1)
return F.relu(self.conv(cor_flo))
# return torch.cat([out, flow], dim=1)
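# Example usage of BasicMotionEncoder (a hedged sketch; shapes are illustrative):
#   enc = BasicMotionEncoder(corr_channel=196, dim=128, pdim=2)
#   flow = torch.randn(2, 2, 46, 62)     # current flow estimate, 2 channels (x, y)
#   corr = torch.randn(2, 196, 46, 62)   # sampled correlation volume
#   mot = enc(flow, corr)                # (2, 128, 46, 62): (dim - pdim) learned features + raw flow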
def conv133_encoder(input_dim, dim, expansion_factor=4):
return nn.Sequential(
nn.Conv2d(input_dim, dim*expansion_factor, kernel_size=1),
nn.GELU(),
nn.Conv2d(dim*expansion_factor, dim*expansion_factor, kernel_size=3, padding=1),
nn.GELU(),
nn.Conv2d(dim*expansion_factor, dim, kernel_size=3, padding=1),
)
class BasicUpdateBlock(nn.Module):
def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128):
        # flowfeat has hdim channels; ctxfeat has cdim channels; typically hdim==cdim.
super(BasicUpdateBlock, self).__init__()
self.encoder = BasicMotionEncoder(corr_channel, dim=cdim)
self.compressor = conv1x1(2*cdim+hdim, hdim)
self.refine = []
for i in range(num_blocks):
self.refine.append(CNBlock1d(hdim, hdim))
self.refine.append(CNBlock2d(hdim, hdim))
self.refine = nn.ModuleList(self.refine)
def forward(self, flowfeat, ctxfeat, corr, flow, S, upsample=True):
BS,C,H,W = flowfeat.shape
B = BS//S
# with torch.no_grad():
motion_features = self.encoder(flow, corr)
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features], dim=1))
for blk in self.refine:
flowfeat = blk(flowfeat, S)
return flowfeat
class FullUpdateBlock(nn.Module):
def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128, pdim=2, use_attn=False):
        # flowfeat has hdim channels; ctxfeat has cdim channels; typically hdim==cdim.
super(FullUpdateBlock, self).__init__()
self.encoder = BasicMotionEncoder(corr_channel, dim=cdim, pdim=pdim)
# note we have hdim==cdim
# compressor chans:
# dim for flowfeat
# dim for ctxfeat
# dim for motion_features
        # pdim for flow (when pdim > 2, e.g., if we give sincos(relflow) instead of raw flow)
# 2 for visconf
if pdim==2:
# hdim==cdim
# dim for flowfeat
# dim for ctxfeat
# dim for motion_features
# 2 for visconf
self.compressor = conv1x1(2*cdim+hdim+2, hdim)
else:
# we concatenate the flow info again, to not lose it (e.g., from the relu)
self.compressor = conv1x1(2*cdim+hdim+2+pdim, hdim)
self.refine = []
for i in range(num_blocks):
self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn))
self.refine.append(CNBlock2d(hdim, hdim))
self.refine = nn.ModuleList(self.refine)
def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
BS,C,H,W = flowfeat.shape
B = BS//S
motion_features = self.encoder(flow, corr)
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features, visconf], dim=1))
for blk in self.refine:
flowfeat = blk(flowfeat, S)
return flowfeat
class MixerUpdateBlock(nn.Module):
def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128):
        # flowfeat has hdim channels; ctxfeat has cdim channels; typically hdim==cdim.
super(MixerUpdateBlock, self).__init__()
self.encoder = BasicMotionEncoder(corr_channel, dim=cdim)
self.compressor = conv1x1(2*cdim+hdim, hdim)
self.refine = []
for i in range(num_blocks):
self.refine.append(CNBlock1d(hdim, hdim, use_mixer=True))
self.refine.append(CNBlock2d(hdim, hdim))
self.refine = nn.ModuleList(self.refine)
def forward(self, flowfeat, ctxfeat, corr, flow, S, upsample=True):
BS,C,H,W = flowfeat.shape
B = BS//S
# with torch.no_grad():
motion_features = self.encoder(flow, corr)
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features], dim=1))
for ii, blk in enumerate(self.refine):
flowfeat = blk(flowfeat, S)
return flowfeat
class FacUpdateBlock(nn.Module):
def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128, pdim=84, use_attn=False):
super(FacUpdateBlock, self).__init__()
self.corr_encoder = conv133_encoder(corr_channel, cdim)
# note we have hdim==cdim
# compressor chans:
# dim for flowfeat
# dim for ctxfeat
# dim for corr
# pdim for flow
# 2 for visconf
self.compressor = conv1x1(2*cdim+hdim+2+pdim, hdim)
self.refine = []
for i in range(num_blocks):
self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn))
self.refine.append(CNBlock2d(hdim, hdim))
self.refine = nn.ModuleList(self.refine)
def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
BS,C,H,W = flowfeat.shape
B = BS//S
corr = self.corr_encoder(corr)
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, corr, visconf, flow], dim=1))
for blk in self.refine:
flowfeat = blk(flowfeat, S)
return flowfeat
class CleanUpdateBlock(nn.Module):
def __init__(self, corr_channel, num_blocks, cdim=128, hdim=256, pdim=84, use_attn=False, use_layer_scale=True):
super(CleanUpdateBlock, self).__init__()
self.corr_encoder = conv133_encoder(corr_channel, cdim)
# compressor chans:
# cdim for flowfeat
# cdim for ctxfeat
# cdim for corrfeat
# pdim for flow
# 2 for visconf
self.compressor = conv1x1(3*cdim+pdim+2, hdim)
self.refine = []
for i in range(num_blocks):
self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn, use_layer_scale=use_layer_scale))
self.refine.append(CNBlock2d(hdim, hdim, use_layer_scale=use_layer_scale))
self.refine = nn.ModuleList(self.refine)
self.final_conv = conv1x1(hdim, cdim)
def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
BS,C,H,W = flowfeat.shape
B = BS//S
corrfeat = self.corr_encoder(corr)
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, corrfeat, flow, visconf], dim=1))
for blk in self.refine:
flowfeat = blk(flowfeat, S)
flowfeat = self.final_conv(flowfeat)
return flowfeat
class RelUpdateBlock(nn.Module):
def __init__(self, corr_channel, num_blocks, cdim=128, hdim=128, pdim=4, use_attn=True, use_mixer=False, use_conv=False, use_convb=False, use_layer_scale=True, no_time=False, no_space=False, no_ctx=False):
super(RelUpdateBlock, self).__init__()
self.motion_encoder = BasicMotionEncoder(corr_channel, dim=hdim, pdim=pdim) # B,hdim,H,W
self.no_ctx = no_ctx
if no_ctx:
self.compressor = conv1x1(cdim+hdim+2, hdim)
else:
self.compressor = conv1x1(2*cdim+hdim+2, hdim)
self.refine = []
for i in range(num_blocks):
if not no_time:
self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn, use_mixer=use_mixer, use_conv=use_conv, use_convb=use_convb, use_layer_scale=use_layer_scale))
if not no_space:
self.refine.append(CNBlock2d(hdim, hdim, use_layer_scale=use_layer_scale))
self.refine = nn.ModuleList(self.refine)
self.final_conv = conv1x1(hdim, cdim)
def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
BS,C,H,W = flowfeat.shape
B = BS//S
motion_features = self.motion_encoder(flow, corr)
if self.no_ctx:
flowfeat = self.compressor(torch.cat([flowfeat, motion_features, visconf], dim=1))
else:
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features, visconf], dim=1))
for blk in self.refine:
flowfeat = blk(flowfeat, S)
flowfeat = self.final_conv(flowfeat)
return flowfeat
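# A hedged smoke-test sketch (all sizes below are illustrative assumptions, not values taken
# from the rest of this repo); run the module directly to exercise RelUpdateBlock end to end.
if __name__ == "__main__":
    B, S, H, W = 1, 8, 24, 32
    upd = RelUpdateBlock(corr_channel=196, num_blocks=2, cdim=128, hdim=128, pdim=4)
    flowfeat = torch.randn(B * S, 128, H, W)   # per-frame flow features (B*S, cdim, H, W)
    ctxfeat = torch.randn(B * S, 128, H, W)    # per-frame context features
    visconf = torch.randn(B * S, 2, H, W)      # visibility/confidence logits
    corr = torch.randn(B * S, 196, H, W)       # sampled correlation features
    flow = torch.randn(B * S, 4, H, W)         # pdim-channel flow encoding
    out = upd(flowfeat, ctxfeat, visconf, corr, flow, S)
    print('RelUpdateBlock output:', tuple(out.shape))  # expect (B*S, 128, H, W)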