import collections
from functools import partial
from itertools import repeat
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


def _ntuple(n):
    """Return a parser that repeats a scalar n times; non-string iterables pass through as tuples."""

    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


to_2tuple = _ntuple(2)
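
# Quick illustration of to_2tuple (for reference only, not executed here):
#   to_2tuple(0.1)        -> (0.1, 0.1)
#   to_2tuple((0.1, 0.2)) -> (0.1, 0.2)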


class ResidualBlock(nn.Module):
    """Two conv layers with normalization and a residual connection.

    When ``stride != 1`` the identity branch is downsampled with a 1x1 conv
    (followed by ``norm3``) so that both branches have matching shapes.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
        super(ResidualBlock, self).__init__()

        # NOTE: padding=1 matches the default kernel_size=3; other kernel sizes
        # change the spatial resolution and break the residual addition.
        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=kernel_size,
            padding=1,
            stride=stride,
            padding_mode="zeros",
        )
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=kernel_size,
            padding=1,
            padding_mode="zeros",
        )
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if stride != 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if stride != 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if stride != 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if stride != 1:
                self.norm3 = nn.Sequential()

        else:
            raise NotImplementedError(f"unknown norm_fn: {norm_fn}")

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
                self.norm3,
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
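
# Illustrative usage (a sketch; the sizes below are arbitrary assumptions, not
# values used elsewhere in this module):
#   block = ResidualBlock(in_planes=64, planes=128, norm_fn="group", stride=2)
#   out = block(torch.randn(2, 64, 32, 32))  # -> (2, 128, 16, 16)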


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # bias and drop may be a single value or a per-layer pair;
        # norm_layer is accepted for interface compatibility but not used here.
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
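
# Illustrative usage (a sketch; the token count and widths are arbitrary
# assumptions):
#   mlp = Mlp(in_features=256, hidden_features=1024)
#   out = mlp(torch.randn(2, 196, 256))  # -> (2, 196, 256)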


class AttnBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
        mlp_ratio=4.0,
        **block_kwargs
    ):
        """
        Self attention block
        """
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

        self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)

        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)

    def forward(self, x, mask=None):
        # NOTE: mask is accepted for interface compatibility but is not applied
        # by this self-attention call.
        x = self.norm1(x)
        attn_output, _ = self.attn(x, x, x)

        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        return x
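
# Illustrative usage (a sketch; batch size, token count, and width below are
# arbitrary assumptions):
#   blk = AttnBlock(hidden_size=256, num_heads=8)
#   tokens = blk(torch.randn(2, 196, 256))  # -> (2, 196, 256)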


class CrossAttnBlock(nn.Module):
    def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
        """
        Cross attention block
        """
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_size)
        # NOTE: the context is normalized and attended with width hidden_size,
        # so this block assumes context_dim == hidden_size; context_dim itself
        # is not used below.
        self.norm_context = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
        )

        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)

    def forward(self, x, context, mask=None):
        x = self.norm1(x)
        context = self.norm_context(context)

        attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)

        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        return x
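
# Illustrative usage (a sketch; shapes are arbitrary assumptions, with the
# context width matching hidden_size as the block requires):
#   blk = CrossAttnBlock(hidden_size=256, context_dim=256, num_heads=8)
#   out = blk(torch.randn(2, 196, 256), context=torch.randn(2, 50, 256))  # -> (2, 196, 256)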