import collections
from functools import partial
from itertools import repeat
from typing import Callable, List, Optional, Sequence

import einops
from einops.layers.torch import Reduce  # needed by MLPMixer's do_reduce path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops.misc import Conv2dNormActivation, Permute
from torchvision.ops.stochastic_depth import StochasticDepth
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

to_2tuple = _ntuple(2)
class InputPadder:
    """Pads images such that dimensions are divisible by a certain stride."""
    def __init__(self, dims, mode='sintel'):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // 64) + 1) * 64 - self.ht) % 64
        pad_wd = (((self.wd // 64) + 1) * 64 - self.wd) % 64
        if mode == 'sintel':
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
        else:
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]

    def pad(self, *inputs):
        return [F.pad(x, self._pad, mode='replicate') for x in inputs]

    def unpad(self, x):
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
        return x[..., c[0]:c[1], c[2]:c[3]]
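
# Illustrative sketch, not part of the model: pad() rounds H and W up to the
# next multiple of 64 with replicate padding, and unpad() inverts it exactly.
# Shapes here are arbitrary assumptions for demonstration.
def _demo_input_padder():
    frame = torch.randn(1, 3, 100, 200)
    padder = InputPadder(frame.shape)
    padded, = padder.pad(frame)
    assert padded.shape[-2] % 64 == 0 and padded.shape[-1] % 64 == 0
    # the round trip is lossless: unpad recovers the original pixels
    assert torch.equal(padder.unpad(padded), frame)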
def bilinear_sampler(
        input, coords,
        align_corners=True,
        padding_mode="border",
        normalize_coords=True):
    """Sample `input` at the (x, y[, t]) locations in `coords` via grid_sample."""
    if input.ndim not in [4, 5]:
        raise ValueError("input must be 4D or 5D.")
    if input.ndim == 4 and not coords.ndim == 4:
        raise ValueError("input is 4D, but coords is not 4D.")
    if input.ndim == 5 and not coords.ndim == 5:
        raise ValueError("input is 5D, but coords is not 5D.")
    if coords.ndim == 5:
        coords = coords[..., [1, 2, 0]]  # t x y -> x y t, to match what grid_sample() expects
    if normalize_coords:
        if align_corners:
            # Normalize coordinates from [0, W/H - 1] to [-1, 1].
            coords = (
                coords
                * torch.tensor([2 / max(size - 1, 1) for size in reversed(input.shape[2:])], device=coords.device)
                - 1
            )
        else:
            # Normalize coordinates from [0, W/H] to [-1, 1].
            coords = coords * torch.tensor([2 / size for size in reversed(input.shape[2:])], device=coords.device) - 1
    return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
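
# Illustrative sketch, not part of the model: with align_corners=True and
# normalize_coords=True, integer pixel coordinates sample the map exactly,
# so gathering at the identity grid reproduces the input.
def _demo_bilinear_sampler():
    feat = torch.arange(12, dtype=torch.float32).view(1, 1, 3, 4)
    ys, xs = torch.meshgrid(torch.arange(3.), torch.arange(4.), indexing="ij")
    coords = torch.stack([xs, ys], dim=-1).unsqueeze(0)  # B,H,W,2 in (x, y) order
    out = bilinear_sampler(feat, coords)
    assert torch.allclose(out, feat)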
class CorrBlock:
    def __init__(self, fmap1, fmap2, corr_levels, corr_radius):
        self.num_levels = corr_levels
        self.radius = corr_radius
        self.corr_pyramid = []
        # all-pairs correlation, with fmap2 halved in resolution at each level
        for i in range(self.num_levels):
            corr = CorrBlock.corr(fmap1, fmap2, 1)
            batch, h1, w1, dim, h2, w2 = corr.shape
            corr = corr.reshape(batch * h1 * w1, dim, h2, w2)
            fmap2 = F.interpolate(fmap2, scale_factor=0.5, mode='area')
            self.corr_pyramid.append(corr)

    def __call__(self, coords, dilation=None):
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape
        if dilation is None:
            dilation = torch.ones(batch, 1, h1, w1, device=coords.device)
        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            device = coords.device
            dx = torch.linspace(-r, r, 2 * r + 1, device=device)
            dy = torch.linspace(-r, r, 2 * r + 1, device=device)
            delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1)
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            delta_lvl = delta_lvl * dilation.view(batch * h1 * w1, 1, 1, 1)
            centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2 ** i
            coords_lvl = centroid_lvl + delta_lvl
            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)
        out = torch.cat(out_pyramid, dim=-1)
        out = out.permute(0, 3, 1, 2).contiguous().float()
        return out

    @staticmethod
    def corr(fmap1, fmap2, num_head):
        batch, dim, h1, w1 = fmap1.shape
        h2, w2 = fmap2.shape[2:]
        fmap1 = fmap1.view(batch, num_head, dim // num_head, h1 * w1)
        fmap2 = fmap2.view(batch, num_head, dim // num_head, h2 * w2)
        corr = fmap1.transpose(2, 3) @ fmap2
        corr = corr.reshape(batch, num_head, h1, w1, h2, w2).permute(0, 2, 3, 1, 4, 5)
        return corr / torch.sqrt(torch.tensor(dim).float())
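
# Illustrative sketch, not part of the model: the pyramid stores one all-pairs
# correlation volume per level, and __call__ gathers a (2r+1)^2 window around
# each (x, y) coordinate at every level, concatenated over levels. The feature
# sizes below are arbitrary assumptions.
def _demo_corr_block():
    fmap1 = torch.randn(1, 64, 8, 8)
    fmap2 = torch.randn(1, 64, 8, 8)
    cb = CorrBlock(fmap1, fmap2, corr_levels=2, corr_radius=3)
    coords = torch.zeros(1, 2, 8, 8)  # B,2,H,W in (x, y) order
    out = cb(coords)
    # one 7x7 window per level, stacked along channels
    assert out.shape == (1, 2 * (2 * 3 + 1) ** 2, 8, 8)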
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution without padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0)

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)

class LayerNorm2d(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        x = x.permute(0, 2, 3, 1)
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        return x
class CNBlock1d(nn.Module):
    def __init__(
        self,
        dim,
        output_dim,
        layer_scale: float = 1e-6,
        stochastic_depth_prob: float = 0,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        dense=True,
        use_attn=True,
        use_mixer=False,
        use_conv=False,
        use_convb=False,
        use_layer_scale=True,
    ) -> None:
        super().__init__()
        self.dense = dense
        self.use_attn = use_attn
        self.use_mixer = use_mixer
        self.use_conv = use_conv
        self.use_layer_scale = use_layer_scale
        if use_attn:
            assert not use_mixer
            assert not use_conv
            assert not use_convb
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        if use_attn:
            num_heads = 8
            self.block = AttnBlock(
                hidden_size=dim,
                num_heads=num_heads,
                mlp_ratio=4,
                attn_class=Attention,
            )
        elif use_mixer:
            self.block = MLPMixerBlock(
                S=16,
                dim=dim,
                depth=1,
                expansion_factor=2,
            )
        elif use_conv:
            self.block = nn.Sequential(
                nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True, padding_mode='zeros'),
                Permute([0, 2, 1]),
                norm_layer(dim),
                nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
                nn.GELU(),
                nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
                Permute([0, 2, 1]),
            )
        elif use_convb:
            self.block = nn.Sequential(
                nn.Conv1d(dim, dim, kernel_size=3, padding=1, bias=True, padding_mode='zeros'),
                Permute([0, 2, 1]),
                norm_layer(dim),
                nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
                nn.GELU(),
                nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
                Permute([0, 2, 1]),
            )
        else:
            raise ValueError("choose one of use_attn, use_mixer, use_conv, use_convb")
        if self.use_layer_scale:
            self.layer_scale = nn.Parameter(torch.ones(dim, 1) * layer_scale)
        else:
            self.layer_scale = 1.0
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        if output_dim != dim:
            self.final = nn.Conv1d(dim, output_dim, kernel_size=1, padding=0)
        else:
            self.final = nn.Identity()
    def forward(self, input, S=None):
        if self.dense:
            assert S is not None
            BS, C, H, W = input.shape
            B = BS // S
            input = einops.rearrange(input, '(b s) c h w -> (b h w) c s', b=B, s=S)
            if self.use_mixer or self.use_attn:
                # mixer/transformer blocks want B,S,C
                result = self.layer_scale * self.block(input.permute(0, 2, 1)).permute(0, 2, 1)
            else:
                result = self.layer_scale * self.block(input)
            result = self.stochastic_depth(result)
            result += input
            result = self.final(result)
            # channels may change in self.final, so do not pin c here
            result = einops.rearrange(result, '(b h w) c s -> (b s) c h w', b=B, h=H, w=W)
        else:
            B, S, C = input.shape
            if S < 7:
                # too short for the kernel-7 conv path; skip the block
                return input
            input = einops.rearrange(input, 'b s c -> b c s')
            if self.use_mixer or self.use_attn:
                # mixer/transformer blocks want B,S,C (mirrors the dense path)
                result = self.layer_scale * self.block(input.permute(0, 2, 1)).permute(0, 2, 1)
            else:
                result = self.layer_scale * self.block(input)
            result = self.stochastic_depth(result)
            result += input
            result = self.final(result)
            result = einops.rearrange(result, 'b c s -> b s c')
        return result
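
# Illustrative sketch, not part of the model: in dense mode the block treats
# the (b s) batch of feature maps as (b h w) sequences of length S, mixes
# along time, then restores the (b s) c h w layout. S=16 matches the mixer's
# hardcoded sequence length; other sizes here are arbitrary assumptions.
def _demo_cnblock1d_dense():
    B, S, C, H, W = 2, 16, 32, 4, 4
    blk = CNBlock1d(C, C, use_attn=False, use_mixer=True)
    x = torch.randn(B * S, C, H, W)
    assert blk(x, S=S).shape == x.shape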
class CNBlock2d(nn.Module):
    def __init__(
        self,
        dim,
        output_dim,
        layer_scale: float = 1e-6,
        stochastic_depth_prob: float = 0,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        use_layer_scale=True,
    ) -> None:
        super().__init__()
        self.use_layer_scale = use_layer_scale
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True, padding_mode='zeros'),
            Permute([0, 2, 3, 1]),
            norm_layer(dim),
            nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
            nn.GELU(),
            nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
            Permute([0, 3, 1, 2]),
        )
        if self.use_layer_scale:
            self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
        else:
            self.layer_scale = 1.0
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        if output_dim != dim:
            self.final = nn.Conv2d(dim, output_dim, kernel_size=1, padding=0)
        else:
            self.final = nn.Identity()

    def forward(self, input, S=None):
        result = self.layer_scale * self.block(input)
        result = self.stochastic_depth(result)
        result += input
        result = self.final(result)
        return result
class CNBlockConfig:
    # Stores the per-stage information listed in Section 3 of the ConvNeXt paper
    def __init__(
        self,
        input_channels: int,
        out_channels: Optional[int],
        num_layers: int,
        downsample: bool,
    ) -> None:
        self.input_channels = input_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.downsample = downsample

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "input_channels={input_channels}"
        s += ", out_channels={out_channels}"
        s += ", num_layers={num_layers}"
        s += ", downsample={downsample}"
        s += ")"
        return s.format(**self.__dict__)
class ConvNeXt(nn.Module):
    def __init__(
        self,
        block_setting: List[CNBlockConfig],
        stochastic_depth_prob: float = 0.0,
        layer_scale: float = 1e-6,
        num_classes: int = 1000,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        init_weights=True):
        super().__init__()
        self.init_weights = init_weights
        if not block_setting:
            raise ValueError("The block_setting should not be empty")
        elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
            raise TypeError("The block_setting should be List[CNBlockConfig]")
        if block is None:
            block = CNBlock2d
        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)
        layers: List[nn.Module] = []
        # Stem
        firstconv_output_channels = block_setting[0].input_channels
        layers.append(
            Conv2dNormActivation(
                3,
                firstconv_output_channels,
                kernel_size=4,
                stride=4,
                padding=0,
                norm_layer=norm_layer,
                activation_layer=None,
                bias=True,
            )
        )
        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
        stage_block_id = 0
        for cnf in block_setting:
            # Bottlenecks
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
                stage.append(block(cnf.input_channels, cnf.input_channels, layer_scale, sd_prob))
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            if cnf.out_channels is not None:
                if cnf.downsample:
                    layers.append(
                        nn.Sequential(
                            norm_layer(cnf.input_channels),
                            nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
                        )
                    )
                else:
                    # convert the 2x2/stride-2 downsampling layer into a 3x3/stride-1 conv with
                    # dilation 2, preserving resolution; the pretrained 2x2 kernel is resized to
                    # 3x3 when weights are loaded below
                    layers.append(
                        nn.Sequential(
                            norm_layer(cnf.input_channels),
                            nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=3, stride=1, padding=2, dilation=2, padding_mode='zeros'),
                        )
                    )
        self.features = nn.Sequential(*layers)
        # self.final_conv = conv1x1(block_setting[-1].input_channels, output_dim)
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        if self.init_weights:
            from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights
            pretrained_dict = convnext_tiny(weights=ConvNeXt_Tiny_Weights.DEFAULT).state_dict()
            # from torchvision.models import convnext_base, ConvNeXt_Base_Weights
            # pretrained_dict = convnext_base(weights=ConvNeXt_Base_Weights.DEFAULT).state_dict()
            model_dict = self.state_dict()
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            for k, v in pretrained_dict.items():
                if k == 'features.4.1.weight':  # the layer normally in charge of 2x2 downsampling
                    # convert the 2x2 kernel to a 3x3 filter; the 4/9 factor keeps the
                    # response magnitude roughly unchanged
                    pretrained_dict[k] = F.interpolate(v, (3, 3), mode='bicubic', align_corners=True) * (4 / 9.0)
            model_dict.update(pretrained_dict)
            self.load_state_dict(model_dict, strict=False)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)
        # x = self.final_conv(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
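
# Illustrative sketch, not part of the model: a ConvNeXt-Tiny-like trunk cut
# off after the third stage, in the block_setting layout the class expects.
# init_weights=False skips downloading torchvision's pretrained checkpoint;
# the stage widths below are assumptions for demonstration.
def _demo_convnext_trunk():
    setting = [
        CNBlockConfig(96, 192, 3, True),
        CNBlockConfig(192, 384, 3, True),
        CNBlockConfig(384, None, 9, False),
    ]
    net = ConvNeXt(setting, init_weights=False)
    out = net(torch.randn(1, 3, 64, 64))
    # stride-4 stem plus two 2x2 downsamples -> overall stride 16
    assert out.shape == (1, 384, 4, 4)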
class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = (
            norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        )
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
class Attention(nn.Module):
    def __init__(
        self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False
    ):
        super().__init__()
        inner_dim = dim_head * num_heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head ** -0.5
        self.heads = num_heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context=None, attn_bias=None):
        B, N1, C = x.shape
        H = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> b h n d', h=H), (q, k, v))
        # default scale is already dim_head ** -0.5; attn_bias is an additive mask
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
        x = einops.rearrange(x, 'b h n d -> b n (h d)')
        return self.to_out(x)
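
# Illustrative sketch, not part of the model: self-attention when context is
# None, cross-attention when a context sequence is given; the output keeps the
# query's shape. Dimensions are arbitrary assumptions.
def _demo_attention():
    attn = Attention(query_dim=64, num_heads=4, dim_head=16)
    x = torch.randn(2, 10, 64)
    ctx = torch.randn(2, 7, 64)
    assert attn(x).shape == (2, 10, 64)
    assert attn(x, context=ctx).shape == (2, 10, 64)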
class CrossAttnBlock(nn.Module):
    def __init__(
        self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.norm_context = nn.LayerNorm(hidden_size)
        self.cross_attn = Attention(
            hidden_size,
            context_dim=context_dim,
            num_heads=num_heads,
            qkv_bias=True,
            **block_kwargs
        )
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
            drop=0,
        )

    def forward(self, x, context, mask=None):
        attn_bias = None
        if mask is not None:
            if mask.shape[1] == x.shape[1]:
                mask = mask[:, None, :, None].expand(
                    -1, self.cross_attn.heads, -1, context.shape[1]
                )
            else:
                mask = mask[:, None, None].expand(
                    -1, self.cross_attn.heads, x.shape[1], -1
                )
            max_neg_value = -torch.finfo(x.dtype).max
            attn_bias = (~mask) * max_neg_value
        x = x + self.cross_attn(
            self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
        )
        x = x + self.mlp(self.norm2(x))
        return x
class AttnBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        attn_class: Callable[..., nn.Module] = Attention,
        mlp_ratio=4.0,
        **block_kwargs
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, dim_head=hidden_size // num_heads)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
            drop=0,
        )

    def forward(self, x, mask=None):
        attn_bias = None
        if mask is not None:
            mask = (
                (mask[:, None] * mask[:, :, None])
                .unsqueeze(1)
                .expand(-1, self.attn.heads, -1, -1)
            )
            max_neg_value = -torch.finfo(x.dtype).max
            attn_bias = (~mask) * max_neg_value
        x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
        x = x + self.mlp(self.norm2(x))
        return x
class ResidualBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=3,
            padding=1,
            stride=stride,
            padding_mode="zeros",
        )
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
        )
        self.relu = nn.ReLU(inplace=True)
        num_groups = planes // 8
        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)
        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)
        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()
        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        if self.downsample is not None:
            x = self.downsample(x)
        return self.relu(x + y)
class BasicEncoder(nn.Module):
    def __init__(self, input_dim=3, output_dim=128, stride=4):
        super(BasicEncoder, self).__init__()
        self.stride = stride
        self.norm_fn = "instance"
        self.in_planes = output_dim // 2
        self.norm1 = nn.InstanceNorm2d(self.in_planes)
        self.norm2 = nn.InstanceNorm2d(output_dim * 2)
        self.conv1 = nn.Conv2d(
            input_dim,
            self.in_planes,
            kernel_size=7,
            stride=2,
            padding=3,
            padding_mode="zeros",
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(output_dim // 2, stride=1)
        self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
        self.layer3 = self._make_layer(output_dim, stride=2)
        self.layer4 = self._make_layer(output_dim, stride=2)
        self.conv2 = nn.Conv2d(
            output_dim * 3 + output_dim // 4,
            output_dim * 2,
            kernel_size=3,
            padding=1,
            padding_mode="zeros",
        )
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.InstanceNorm2d)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)
        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        _, _, H, W = x.shape
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)
        a = self.layer1(x)
        b = self.layer2(a)
        c = self.layer3(b)
        d = self.layer4(c)

        def _bilinear_interpolate(x):
            return F.interpolate(
                x,
                (H // self.stride, W // self.stride),
                mode="bilinear",
                align_corners=True,
            )

        a = _bilinear_interpolate(a)
        b = _bilinear_interpolate(b)
        c = _bilinear_interpolate(c)
        d = _bilinear_interpolate(d)
        x = self.conv2(torch.cat([a, b, c, d], dim=1))
        x = self.norm2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        return x
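
# Illustrative sketch, not part of the model: the encoder fuses four pyramid
# levels at a common resolution and returns output_dim channels at 1/stride of
# the input size. Input size here is an arbitrary assumption.
def _demo_basic_encoder():
    enc = BasicEncoder(input_dim=3, output_dim=128, stride=4)
    out = enc(torch.randn(1, 3, 64, 64))
    assert out.shape == (1, 128, 16, 16)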
class EfficientUpdateFormer(nn.Module):
    """
    Transformer model that updates track estimates.
    """
    def __init__(
        self,
        space_depth=6,
        time_depth=6,
        input_dim=320,
        hidden_size=384,
        num_heads=8,
        output_dim=130,
        mlp_ratio=4.0,
        num_virtual_tracks=64,
        add_space_attn=True,
        linear_layer_for_vis_conf=False,
        use_time_conv=False,
        use_time_mixer=False,
    ):
        super().__init__()
        self.out_channels = 2
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
        if linear_layer_for_vis_conf:
            self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
            self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
        else:
            self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
        self.num_virtual_tracks = num_virtual_tracks
        # attribute name kept as-is ("virual") for checkpoint compatibility
        self.virual_tracks = nn.Parameter(
            torch.randn(1, num_virtual_tracks, 1, hidden_size)
        )
        self.add_space_attn = add_space_attn
        self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
        if use_time_conv:
            self.time_blocks = nn.ModuleList(
                [
                    # use_attn=False, use_conv=True selects the conv path of CNBlock1d
                    CNBlock1d(hidden_size, hidden_size, dense=False, use_attn=False, use_conv=True)
                    for _ in range(time_depth)
                ]
            )
        elif use_time_mixer:
            self.time_blocks = nn.ModuleList(
                [
                    MLPMixerBlock(
                        S=16,
                        dim=hidden_size,
                        depth=1,
                    )
                    for _ in range(time_depth)
                ]
            )
        else:
            self.time_blocks = nn.ModuleList(
                [
                    AttnBlock(
                        hidden_size,
                        num_heads,
                        mlp_ratio=mlp_ratio,
                        attn_class=Attention,
                    )
                    for _ in range(time_depth)
                ]
            )
        if add_space_attn:
            self.space_virtual_blocks = nn.ModuleList(
                [
                    AttnBlock(
                        hidden_size,
                        num_heads,
                        mlp_ratio=mlp_ratio,
                        attn_class=Attention,
                    )
                    for _ in range(space_depth)
                ]
            )
            self.space_point2virtual_blocks = nn.ModuleList(
                [
                    CrossAttnBlock(
                        hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
                    )
                    for _ in range(space_depth)
                ]
            )
            self.space_virtual2point_blocks = nn.ModuleList(
                [
                    CrossAttnBlock(
                        hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
                    )
                    for _ in range(space_depth)
                ]
            )
            assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
        self.initialize_weights()

    def initialize_weights(self):
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
            if self.linear_layer_for_vis_conf:
                torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)

        def _trunc_init(module):
            """ViT weight initialization, original timm impl (for reproducibility)"""
            if isinstance(module, nn.Linear):
                torch.nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        self.apply(_basic_init)

    def forward(self, input_tensor, mask=None, add_space_attn=True):
        tokens = self.input_transform(input_tensor)
        B, _, T, _ = tokens.shape
        virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
        tokens = torch.cat([tokens, virtual_tokens], dim=1)
        _, N, _, _ = tokens.shape
        j = 0
        for i in range(len(self.time_blocks)):
            time_tokens = tokens.contiguous().view(B * N, T, -1)  # B N T C -> (B N) T C
            time_tokens = self.time_blocks[i](time_tokens)
            tokens = time_tokens.view(B, N, T, -1)  # (B N) T C -> B N T C
            if (
                add_space_attn
                and hasattr(self, "space_virtual_blocks")
                and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
            ):
                space_tokens = (
                    tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
                )  # B N T C -> (B T) N C
                point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
                virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
                virtual_tokens = self.space_virtual2point_blocks[j](
                    virtual_tokens, point_tokens, mask=mask
                )
                virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
                point_tokens = self.space_point2virtual_blocks[j](
                    point_tokens, virtual_tokens, mask=mask
                )
                space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
                tokens = space_tokens.view(B, T, N, -1).permute(
                    0, 2, 1, 3
                )  # (B T) N C -> B N T C
                j += 1
        tokens = tokens[:, : N - self.num_virtual_tracks]
        flow = self.flow_head(tokens)
        if self.linear_layer_for_vis_conf:
            vis_conf = self.vis_conf_head(tokens)
            flow = torch.cat([flow, vis_conf], dim=-1)
        return flow
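
# Illustrative sketch, not part of the model: tokens come in as B,N,T,C
# (tracks x time); virtual tracks are appended for space attention and
# stripped again before the flow head. Token counts are assumptions.
def _demo_update_former():
    former = EfficientUpdateFormer(input_dim=320, hidden_size=384, output_dim=130)
    tokens = torch.randn(1, 5, 8, 320)  # B, N, T, input_dim
    assert former(tokens).shape == (1, 5, 8, 130)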
class MMPreNormResidual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x)) + x

def MMFeedForward(dim, expansion_factor=4, dropout=0., dense=nn.Linear):
    return nn.Sequential(
        dense(dim, dim * expansion_factor),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(dim * expansion_factor, dim),
        nn.Dropout(dropout)
    )

def MLPMixer(S, input_dim, dim, output_dim, depth=6, expansion_factor=4, dropout=0., do_reduce=False):
    # input comes in as B,S,C, as standard for mlp and transformer
    # chan_first treats S as the channel dim, and transforms it to a new S
    # chan_last treats C as the channel dim, and transforms it to a new C
    chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
    if do_reduce:
        return nn.Sequential(
            nn.Linear(input_dim, dim),
            *[nn.Sequential(
                MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
                MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
            ) for _ in range(depth)],
            nn.LayerNorm(dim),
            Reduce('b n c -> b c', 'mean'),
            nn.Linear(dim, output_dim)
        )
    else:
        return nn.Sequential(
            nn.Linear(input_dim, dim),
            *[nn.Sequential(
                MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
                MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
            ) for _ in range(depth)],
        )

def MLPMixerBlock(S, dim, depth=1, expansion_factor=4, dropout=0., do_reduce=False):
    # input comes in as B,S,C, as standard for mlp and transformer
    # chan_first treats S as the channel dim, and transforms it to a new S
    # chan_last treats C as the channel dim, and transforms it to a new C
    chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
    return nn.Sequential(
        *[nn.Sequential(
            MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
            MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
        ) for _ in range(depth)],
    )
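
# Illustrative sketch, not part of the model: the mixer alternates a
# token-mixing MLP over the fixed-length S axis with a channel-mixing MLP over
# C, so both input and output are B,S,C. The dim below is an assumption.
def _demo_mlp_mixer_block():
    mixer = MLPMixerBlock(S=16, dim=64, depth=2)
    x = torch.randn(3, 16, 64)
    assert mixer(x).shape == (3, 16, 64)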
class MlpUpdateFormer(nn.Module):
    """
    Transformer model that updates track estimates.
    """
    def __init__(
        self,
        space_depth=6,
        time_depth=6,
        input_dim=320,
        hidden_size=384,
        num_heads=8,
        output_dim=130,
        mlp_ratio=4.0,
        num_virtual_tracks=64,
        add_space_attn=True,
        linear_layer_for_vis_conf=False,
    ):
        super().__init__()
        self.out_channels = 2
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
        if linear_layer_for_vis_conf:
            self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
            self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
        else:
            self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
        self.num_virtual_tracks = num_virtual_tracks
        # attribute name kept as-is ("virual") for checkpoint compatibility
        self.virual_tracks = nn.Parameter(
            torch.randn(1, num_virtual_tracks, 1, hidden_size)
        )
        self.add_space_attn = add_space_attn
        self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
        self.time_blocks = nn.ModuleList(
            [
                MLPMixer(
                    S=16,
                    input_dim=hidden_size,
                    dim=hidden_size,
                    output_dim=hidden_size,
                    depth=1,
                )
                for _ in range(time_depth)
            ]
        )
        if add_space_attn:
            self.space_virtual_blocks = nn.ModuleList(
                [
                    AttnBlock(
                        hidden_size,
                        num_heads,
                        mlp_ratio=mlp_ratio,
                        attn_class=Attention,
                    )
                    for _ in range(space_depth)
                ]
            )
            self.space_point2virtual_blocks = nn.ModuleList(
                [
                    CrossAttnBlock(
                        hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
                    )
                    for _ in range(space_depth)
                ]
            )
            self.space_virtual2point_blocks = nn.ModuleList(
                [
                    CrossAttnBlock(
                        hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
                    )
                    for _ in range(space_depth)
                ]
            )
            assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
        self.initialize_weights()

    def initialize_weights(self):
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
            if self.linear_layer_for_vis_conf:
                torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)

        def _trunc_init(module):
            """ViT weight initialization, original timm impl (for reproducibility)"""
            if isinstance(module, nn.Linear):
                torch.nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        self.apply(_basic_init)

    def forward(self, input_tensor, mask=None, add_space_attn=True):
        tokens = self.input_transform(input_tensor)
        B, _, T, _ = tokens.shape
        virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
        tokens = torch.cat([tokens, virtual_tokens], dim=1)
        _, N, _, _ = tokens.shape
        j = 0
        for i in range(len(self.time_blocks)):
            time_tokens = tokens.contiguous().view(B * N, T, -1)  # B N T C -> (B N) T C
            time_tokens = self.time_blocks[i](time_tokens)
            tokens = time_tokens.view(B, N, T, -1)  # (B N) T C -> B N T C
            if (
                add_space_attn
                and hasattr(self, "space_virtual_blocks")
                and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
            ):
                space_tokens = (
                    tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
                )  # B N T C -> (B T) N C
                point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
                virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
                virtual_tokens = self.space_virtual2point_blocks[j](
                    virtual_tokens, point_tokens, mask=mask
                )
                virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
                point_tokens = self.space_point2virtual_blocks[j](
                    point_tokens, virtual_tokens, mask=mask
                )
                space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
                tokens = space_tokens.view(B, T, N, -1).permute(
                    0, 2, 1, 3
                )  # (B T) N C -> B N T C
                j += 1
        tokens = tokens[:, : N - self.num_virtual_tracks]
        flow = self.flow_head(tokens)
        if self.linear_layer_for_vis_conf:
            vis_conf = self.vis_conf_head(tokens)
            flow = torch.cat([flow, vis_conf], dim=-1)
        return flow
class BasicMotionEncoder(nn.Module):
    def __init__(self, corr_channel, dim=128, pdim=2):
        super(BasicMotionEncoder, self).__init__()
        self.pdim = pdim
        self.convc1 = nn.Conv2d(corr_channel, dim * 4, 1, padding=0)
        self.convc2 = nn.Conv2d(dim * 4, dim + dim // 2, 3, padding=1)
        if pdim == 2 or pdim == 4:
            self.convf1 = nn.Conv2d(pdim, dim * 2, 5, padding=2)
            self.convf2 = nn.Conv2d(dim * 2, dim // 2, 3, padding=1)
            self.conv = nn.Conv2d(dim * 2, dim - pdim, 3, padding=1)
        else:
            self.conv = nn.Conv2d(dim + dim // 2 + pdim, dim, 3, padding=1)

    def forward(self, flow, corr):
        cor = F.relu(self.convc1(corr))
        cor = F.relu(self.convc2(cor))
        if self.pdim == 2 or self.pdim == 4:
            flo = F.relu(self.convf1(flow))
            flo = F.relu(self.convf2(flo))
            cor_flo = torch.cat([cor, flo], dim=1)
            out = F.relu(self.conv(cor_flo))
            return torch.cat([out, flow], dim=1)
        else:
            # the flow is already encoded into something useful; concatenate directly
            cor_flo = torch.cat([cor, flow], dim=1)
            return F.relu(self.conv(cor_flo))
def conv133_encoder(input_dim, dim, expansion_factor=4):
    return nn.Sequential(
        nn.Conv2d(input_dim, dim * expansion_factor, kernel_size=1),
        nn.GELU(),
        nn.Conv2d(dim * expansion_factor, dim * expansion_factor, kernel_size=3, padding=1),
        nn.GELU(),
        nn.Conv2d(dim * expansion_factor, dim, kernel_size=3, padding=1),
    )
class BasicUpdateBlock(nn.Module):
    def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128):
        # flowfeat has hdim channels; ctxfeat has cdim. Typically hdim == cdim.
        super(BasicUpdateBlock, self).__init__()
        self.encoder = BasicMotionEncoder(corr_channel, dim=cdim)
        self.compressor = conv1x1(2 * cdim + hdim, hdim)
        self.refine = []
        for i in range(num_blocks):
            self.refine.append(CNBlock1d(hdim, hdim))
            self.refine.append(CNBlock2d(hdim, hdim))
        self.refine = nn.ModuleList(self.refine)

    def forward(self, flowfeat, ctxfeat, corr, flow, S, upsample=True):
        BS, C, H, W = flowfeat.shape
        B = BS // S
        motion_features = self.encoder(flow, corr)
        flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features], dim=1))
        for blk in self.refine:
            flowfeat = blk(flowfeat, S)
        return flowfeat
class FullUpdateBlock(nn.Module):
    def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128, pdim=2, use_attn=False):
        # flowfeat has hdim channels; ctxfeat has cdim. Typically hdim == cdim.
        super(FullUpdateBlock, self).__init__()
        self.encoder = BasicMotionEncoder(corr_channel, dim=cdim, pdim=pdim)
        # compressor input channels:
        #   hdim for flowfeat, cdim for ctxfeat, cdim for motion_features, 2 for visconf,
        #   plus pdim for the flow itself when pdim != 2 (e.g., if we give sincos(relflow))
        if pdim == 2:
            self.compressor = conv1x1(2 * cdim + hdim + 2, hdim)
        else:
            # concatenate the flow info again, so it is not lost (e.g., to the relu)
            self.compressor = conv1x1(2 * cdim + hdim + 2 + pdim, hdim)
        self.refine = []
        for i in range(num_blocks):
            self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn))
            self.refine.append(CNBlock2d(hdim, hdim))
        self.refine = nn.ModuleList(self.refine)

    def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
        BS, C, H, W = flowfeat.shape
        B = BS // S
        motion_features = self.encoder(flow, corr)
        flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features, visconf], dim=1))
        for blk in self.refine:
            flowfeat = blk(flowfeat, S)
        return flowfeat
class MixerUpdateBlock(nn.Module):
    def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128):
        # flowfeat has hdim channels; ctxfeat has cdim. Typically hdim == cdim.
        super(MixerUpdateBlock, self).__init__()
        self.encoder = BasicMotionEncoder(corr_channel, dim=cdim)
        self.compressor = conv1x1(2 * cdim + hdim, hdim)
        self.refine = []
        for i in range(num_blocks):
            # use_attn=False is required so the mixer path is actually selected
            self.refine.append(CNBlock1d(hdim, hdim, use_attn=False, use_mixer=True))
            self.refine.append(CNBlock2d(hdim, hdim))
        self.refine = nn.ModuleList(self.refine)

    def forward(self, flowfeat, ctxfeat, corr, flow, S, upsample=True):
        BS, C, H, W = flowfeat.shape
        B = BS // S
        motion_features = self.encoder(flow, corr)
        flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features], dim=1))
        for blk in self.refine:
            flowfeat = blk(flowfeat, S)
        return flowfeat
class FacUpdateBlock(nn.Module):
    def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128, pdim=84, use_attn=False):
        super(FacUpdateBlock, self).__init__()
        self.corr_encoder = conv133_encoder(corr_channel, cdim)
        # note we have hdim == cdim
        # compressor input channels:
        #   hdim for flowfeat, cdim for ctxfeat, cdim for corr,
        #   pdim for flow, 2 for visconf
        self.compressor = conv1x1(2 * cdim + hdim + 2 + pdim, hdim)
        self.refine = []
        for i in range(num_blocks):
            self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn))
            self.refine.append(CNBlock2d(hdim, hdim))
        self.refine = nn.ModuleList(self.refine)

    def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
        BS, C, H, W = flowfeat.shape
        B = BS // S
        corr = self.corr_encoder(corr)
        flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, corr, visconf, flow], dim=1))
        for blk in self.refine:
            flowfeat = blk(flowfeat, S)
        return flowfeat
class CleanUpdateBlock(nn.Module):
    def __init__(self, corr_channel, num_blocks, cdim=128, hdim=256, pdim=84, use_attn=False, use_layer_scale=True):
        super(CleanUpdateBlock, self).__init__()
        self.corr_encoder = conv133_encoder(corr_channel, cdim)
        # compressor input channels:
        #   cdim for flowfeat, cdim for ctxfeat, cdim for corrfeat,
        #   pdim for flow, 2 for visconf
        self.compressor = conv1x1(3 * cdim + pdim + 2, hdim)
        self.refine = []
        for i in range(num_blocks):
            self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn, use_layer_scale=use_layer_scale))
            self.refine.append(CNBlock2d(hdim, hdim, use_layer_scale=use_layer_scale))
        self.refine = nn.ModuleList(self.refine)
        self.final_conv = conv1x1(hdim, cdim)

    def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
        BS, C, H, W = flowfeat.shape
        B = BS // S
        corrfeat = self.corr_encoder(corr)
        flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, corrfeat, flow, visconf], dim=1))
        for blk in self.refine:
            flowfeat = blk(flowfeat, S)
        flowfeat = self.final_conv(flowfeat)
        return flowfeat
class RelUpdateBlock(nn.Module):
    def __init__(self, corr_channel, num_blocks, cdim=128, hdim=128, pdim=4,
                 use_attn=True, use_mixer=False, use_conv=False, use_convb=False,
                 use_layer_scale=True, no_time=False, no_space=False, no_ctx=False):
        super(RelUpdateBlock, self).__init__()
        self.motion_encoder = BasicMotionEncoder(corr_channel, dim=hdim, pdim=pdim)  # B,hdim,H,W
        self.no_ctx = no_ctx
        if no_ctx:
            self.compressor = conv1x1(cdim + hdim + 2, hdim)
        else:
            self.compressor = conv1x1(2 * cdim + hdim + 2, hdim)
        self.refine = []
        for i in range(num_blocks):
            if not no_time:
                self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn, use_mixer=use_mixer, use_conv=use_conv, use_convb=use_convb, use_layer_scale=use_layer_scale))
            if not no_space:
                self.refine.append(CNBlock2d(hdim, hdim, use_layer_scale=use_layer_scale))
        self.refine = nn.ModuleList(self.refine)
        self.final_conv = conv1x1(hdim, cdim)

    def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
        BS, C, H, W = flowfeat.shape
        B = BS // S
        motion_features = self.motion_encoder(flow, corr)
        if self.no_ctx:
            flowfeat = self.compressor(torch.cat([flowfeat, motion_features, visconf], dim=1))
        else:
            flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features, visconf], dim=1))
        for blk in self.refine:
            flowfeat = blk(flowfeat, S)
        flowfeat = self.final_conv(flowfeat)
        return flowfeat
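
# Illustrative sketch, not part of the model: one refinement step of the update
# block over B sequences of length S, flattened to (B S) in the batch
# dimension. The corr_channel value and spatial sizes are assumptions; S=16
# matches the temporal attention path in CNBlock1d.
def _demo_rel_update_block():
    B, S, C, H, W = 1, 16, 128, 4, 4
    corr_channel = 2 * (2 * 3 + 1) ** 2  # e.g., two pyramid levels, radius 3
    blk = RelUpdateBlock(corr_channel, num_blocks=1, cdim=C, hdim=C, pdim=4)
    flowfeat = torch.randn(B * S, C, H, W)
    ctxfeat = torch.randn(B * S, C, H, W)
    visconf = torch.randn(B * S, 2, H, W)
    corr = torch.randn(B * S, corr_channel, H, W)
    flow = torch.randn(B * S, 4, H, W)
    out = blk(flowfeat, ctxfeat, visconf, corr, flow, S)
    assert out.shape == (B * S, C, H, W)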