'''
DFormerv2: Geometry Self-Attention for RGBD Semantic Segmentation
Code: https://github.com/VCIP-RGBD/DFormer
Author: yinbow
Email: [email protected]
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from collections import OrderedDict
from typing import Tuple

from timm.models.layers import DropPath, trunc_normal_
from torch.nn.modules.batchnorm import _BatchNorm
from mmengine.runner.checkpoint import load_state_dict
class LayerNorm2d(nn.Module):
def __init__(self, dim):
super().__init__()
self.norm = nn.LayerNorm(dim, eps=1e-6)
def forward(self, x: torch.Tensor):
'''
input shape (b c h w)
'''
x = x.permute(0, 2, 3, 1).contiguous() #(b h w c)
x = self.norm(x) #(b h w c)
x = x.permute(0, 3, 1, 2).contiguous()
return x
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding
"""
def __init__(self, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Sequential(
nn.Conv2d(in_chans, embed_dim//2, 3, 2, 1),
nn.SyncBatchNorm(embed_dim//2),
nn.GELU(),
nn.Conv2d(embed_dim//2, embed_dim//2, 3, 1, 1),
nn.SyncBatchNorm(embed_dim//2),
nn.GELU(),
nn.Conv2d(embed_dim//2, embed_dim, 3, 2, 1),
nn.SyncBatchNorm(embed_dim),
nn.GELU(),
nn.Conv2d(embed_dim, embed_dim, 3, 1, 1),
nn.SyncBatchNorm(embed_dim)
)
    def forward(self, x):
        x = self.proj(x).permute(0, 2, 3, 1)  # (b, h/4, w/4, embed_dim)
        return x
class DWConv2d(nn.Module):
def __init__(self, dim, kernel_size, stride, padding):
super().__init__()
self.dwconv = nn.Conv2d(dim, dim, kernel_size, stride, padding, groups=dim)
def forward(self, x: torch.Tensor):
'''
input (b h w c)
'''
x = x.permute(0, 3, 1, 2)
x = self.dwconv(x)
x = x.permute(0, 2, 3, 1)
return x
class PatchMerging(nn.Module):
"""
Patch Merging Layer.
"""
def __init__(self, dim, out_dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Conv2d(dim, out_dim, 3, 2, 1)
self.norm = nn.SyncBatchNorm(out_dim)
def forward(self, x):
'''
x: B H W C
'''
x = x.permute(0, 3, 1, 2).contiguous() #(b c h w)
x = self.reduction(x) #(b oc oh ow)
x = self.norm(x)
x = x.permute(0, 2, 3, 1) #(b oh ow oc)
return x
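# Rotary-style relative position encoding: rotates each consecutive feature
# pair (x_{2i}, x_{2i+1}) by a position-dependent angle, analogous to RoPE in
# language models. `sin` and `cos` carry each angle twice (see GeoPriorGen),
# so both elements of a pair share one rotation angle.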
def angle_transform(x, sin, cos):
x1 = x[:, :, :, :, ::2]
x2 = x[:, :, :, :, 1::2]
return (x * cos) + (torch.stack([-x2, x1], dim=-1).flatten(-2) * sin)
class GeoPriorGen(nn.Module):
    def __init__(self, embed_dim, num_heads, initial_value, heads_range):
        super().__init__()
        # RoPE-style frequencies; each frequency is repeated twice so that a
        # feature pair shares one rotation angle (see angle_transform above)
        angle = 1.0 / (10000 ** torch.linspace(0, 1, embed_dim // num_heads // 2))
        angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
        # learnable weights balancing the spatial and depth decay terms
        self.weight = nn.Parameter(torch.ones(2, 1, 1, 1), requires_grad=True)
        # per-head log-decay rates (all negative), spread across heads_range
        decay = torch.log(1 - 2 ** (-initial_value - heads_range * torch.arange(num_heads, dtype=torch.float) / num_heads))
        self.register_buffer('angle', angle)
        self.register_buffer('decay', decay)
    def generate_depth_decay(self, H: int, W: int, depth_grid):
        '''
        generate 2d depth decay mask; the result is (B, n, HW, HW)
        H, W are the numbers of patches at each column and row
        '''
        B, _, H, W = depth_grid.shape
        grid_d = depth_grid.reshape(B, H * W, 1)
        # pairwise absolute depth difference between all patch pairs
        mask_d = grid_d[:, :, None, :] - grid_d[:, None, :, :]
        mask_d = (mask_d.abs()).sum(dim=-1)
        mask_d = mask_d.unsqueeze(1) * self.decay[None, :, None, None]
        return mask_d
    def generate_pos_decay(self, H: int, W: int):
        '''
        generate 2d spatial decay mask; the result is (n, HW, HW)
        H, W are the numbers of patches at each column and row
        '''
        index_h = torch.arange(H).to(self.decay)
        index_w = torch.arange(W).to(self.decay)
        grid = torch.meshgrid([index_h, index_w], indexing='ij')
        grid = torch.stack(grid, dim=-1).reshape(H * W, 2)
        # pairwise Manhattan distance between patch coordinates
        mask = grid[:, None, :] - grid[None, :, :]
        mask = (mask.abs()).sum(dim=-1)
        mask = mask * self.decay[:, None, None]
        return mask
    def generate_1d_depth_decay(self, H, W, depth_grid):
        '''
        generate 1d depth decay mask along the last axis of depth_grid;
        for depth_grid of shape (B, 1, W, H) the result is (B, n, W, H, H)
        '''
        mask = depth_grid[:, :, :, :, None] - depth_grid[:, :, :, None, :]
        mask = mask.abs()
        mask = mask * self.decay[:, None, None, None]
        assert mask.shape[2:] == (W, H, H)
        return mask
def generate_1d_decay(self, l: int):
'''
generate 1d decay mask, the result is l*l
'''
index = torch.arange(l).to(self.decay)
mask = index[:, None] - index[None, :]
mask = mask.abs()
mask = mask * self.decay[:, None, None]
return mask
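    # Note: `decay` is negative (log of a value in (0, 1)), so every mask above
    # is a non-positive additive bias whose magnitude grows with spatial or
    # depth distance; added before softmax, it makes attention decay with
    # distance at a different rate per head.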
    def forward(self, HW_tuple: Tuple[int, int], depth_map, split_or_not=False):
        '''
        depth_map: depth patches, (b 1 h w)
        HW_tuple: (H, W)
        H * W == l
        '''
        depth_map = F.interpolate(depth_map, size=HW_tuple, mode='bilinear', align_corners=False)
        index = torch.arange(HW_tuple[0] * HW_tuple[1]).to(self.decay)
        sin = torch.sin(index[:, None] * self.angle[None, :]).reshape(HW_tuple[0], HW_tuple[1], -1)
        cos = torch.cos(index[:, None] * self.angle[None, :]).reshape(HW_tuple[0], HW_tuple[1], -1)
        if split_or_not:
            # axial (decomposed) prior: separate 1d masks along height and width
            mask_d_h = self.generate_1d_depth_decay(HW_tuple[0], HW_tuple[1], depth_map.transpose(-2, -1))
            mask_d_w = self.generate_1d_depth_decay(HW_tuple[1], HW_tuple[0], depth_map)
            mask_h = self.generate_1d_decay(HW_tuple[0])
            mask_w = self.generate_1d_decay(HW_tuple[1])
            mask_h = self.weight[0] * mask_h.unsqueeze(0).unsqueeze(2) + self.weight[1] * mask_d_h
            mask_w = self.weight[0] * mask_w.unsqueeze(0).unsqueeze(2) + self.weight[1] * mask_d_w
            geo_prior = ((sin, cos), (mask_h, mask_w))
        else:
            # full prior: one (HW x HW) mask mixing spatial and depth distance
            mask = self.generate_pos_decay(HW_tuple[0], HW_tuple[1])
            mask_d = self.generate_depth_decay(HW_tuple[0], HW_tuple[1], depth_map)
            mask = self.weight[0] * mask + self.weight[1] * mask_d
            geo_prior = ((sin, cos), mask)
        return geo_prior
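# Decomposed (axial) geometry self-attention: attends along the width axis and
# then the height axis, cutting the cost from O((HW)^2) for full attention to
# O(HW * (H + W)). Used in the three high-resolution stages of the backbone.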
class Decomposed_GSA(nn.Module):
def __init__(self, embed_dim, num_heads, value_factor=1):
super().__init__()
self.factor = value_factor
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = self.embed_dim * self.factor // num_heads
self.key_dim = self.embed_dim // num_heads
self.scaling = self.key_dim ** -0.5
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.v_proj = nn.Linear(embed_dim, embed_dim * self.factor, bias=True)
self.lepe = DWConv2d(embed_dim, 5, 1, 2)
self.out_proj = nn.Linear(embed_dim*self.factor, embed_dim, bias=True)
self.reset_parameters()
    def forward(self, x: torch.Tensor, rel_pos, split_or_not=False):
        '''
        x: (b h w c)
        rel_pos: ((sin, cos), (mask_h, mask_w))
        '''
        bsz, h, w, _ = x.size()
        (sin, cos), (mask_h, mask_w) = rel_pos
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
lepe = self.lepe(v)
k = k * self.scaling
q = q.view(bsz, h, w, self.num_heads, self.key_dim).permute(0, 3, 1, 2, 4) #(b n h w d1)
k = k.view(bsz, h, w, self.num_heads, self.key_dim).permute(0, 3, 1, 2, 4) #(b n h w d1)
qr = angle_transform(q, sin, cos)
kr = angle_transform(k, sin, cos)
        # attention along the width axis
        qr_w = qr.transpose(1, 2)
        kr_w = kr.transpose(1, 2)
        v = v.reshape(bsz, h, w, self.num_heads, -1).permute(0, 1, 3, 2, 4)
        qk_mat_w = qr_w @ kr_w.transpose(-1, -2)
        qk_mat_w = qk_mat_w + mask_w.transpose(1, 2)
        qk_mat_w = torch.softmax(qk_mat_w, -1)
        v = torch.matmul(qk_mat_w, v)
        # attention along the height axis
        qr_h = qr.permute(0, 3, 1, 2, 4)
        kr_h = kr.permute(0, 3, 1, 2, 4)
        v = v.permute(0, 3, 2, 1, 4)
        qk_mat_h = qr_h @ kr_h.transpose(-1, -2)
        qk_mat_h = qk_mat_h + mask_h.transpose(1, 2)
        qk_mat_h = torch.softmax(qk_mat_h, -1)
        output = torch.matmul(qk_mat_h, v)
output = output.permute(0, 3, 1, 2, 4).flatten(-2, -1)
output = output + lepe
output = self.out_proj(output)
return output
def reset_parameters(self):
nn.init.xavier_normal_(self.q_proj.weight, gain=2 ** -2.5)
nn.init.xavier_normal_(self.k_proj.weight, gain=2 ** -2.5)
nn.init.xavier_normal_(self.v_proj.weight, gain=2 ** -2.5)
nn.init.xavier_normal_(self.out_proj.weight)
nn.init.constant_(self.out_proj.bias, 0.0)
class Full_GSA(nn.Module):
def __init__(self, embed_dim, num_heads, value_factor=1):
super().__init__()
self.factor = value_factor
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = self.embed_dim * self.factor // num_heads
self.key_dim = self.embed_dim // num_heads
self.scaling = self.key_dim ** -0.5
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.v_proj = nn.Linear(embed_dim, embed_dim * self.factor, bias=True)
self.lepe = DWConv2d(embed_dim, 5, 1, 2)
self.out_proj = nn.Linear(embed_dim*self.factor, embed_dim, bias=True)
self.reset_parameters()
def forward(self, x: torch.Tensor, rel_pos, split_or_not=False):
        '''
        x: (b h w c)
        rel_pos: ((sin, cos), mask), mask of shape (b n l l) with l = h * w
        '''
bsz, h, w, _ = x.size()
(sin, cos), mask = rel_pos
assert h*w == mask.size(3)
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
lepe = self.lepe(v)
k = k * self.scaling
q = q.view(bsz, h, w, self.num_heads, -1).permute(0, 3, 1, 2, 4)
k = k.view(bsz, h, w, self.num_heads, -1).permute(0, 3, 1, 2, 4)
qr = angle_transform(q, sin, cos)
kr = angle_transform(k, sin, cos)
qr = qr.flatten(2, 3)
kr = kr.flatten(2, 3)
vr = v.reshape(bsz, h, w, self.num_heads, -1).permute(0, 3, 1, 2, 4)
vr = vr.flatten(2, 3)
qk_mat = qr @ kr.transpose(-1, -2)
qk_mat = qk_mat + mask
qk_mat = torch.softmax(qk_mat, -1)
output = torch.matmul(qk_mat, vr)
output = output.transpose(1, 2).reshape(bsz, h, w, -1)
output = output + lepe
output = self.out_proj(output)
return output
def reset_parameters(self):
nn.init.xavier_normal_(self.q_proj.weight, gain=2 ** -2.5)
nn.init.xavier_normal_(self.k_proj.weight, gain=2 ** -2.5)
nn.init.xavier_normal_(self.v_proj.weight, gain=2 ** -2.5)
nn.init.xavier_normal_(self.out_proj.weight)
nn.init.constant_(self.out_proj.bias, 0.0)
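# Position-wise feed-forward network. The optional depth-wise convolution
# (subconv) adds local spatial mixing between the two linear layers, and the
# optional sub-layernorm (subln) normalizes the hidden activations.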
class FeedForwardNetwork(nn.Module):
def __init__(
self,
embed_dim,
ffn_dim,
activation_fn=F.gelu,
dropout=0.0,
activation_dropout=0.0,
layernorm_eps=1e-6,
subln=False,
subconv=True
):
super().__init__()
self.embed_dim = embed_dim
self.activation_fn = activation_fn
self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
self.dropout_module = torch.nn.Dropout(dropout)
self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
self.ffn_layernorm = nn.LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None
self.dwconv = DWConv2d(ffn_dim, 3, 1, 1) if subconv else None
def reset_parameters(self):
self.fc1.reset_parameters()
self.fc2.reset_parameters()
if self.ffn_layernorm is not None:
self.ffn_layernorm.reset_parameters()
def forward(self, x: torch.Tensor):
'''
input shape: (b h w c)
'''
x = self.fc1(x)
x = self.activation_fn(x)
x = self.activation_dropout_module(x)
residual = x
if self.dwconv is not None:
x = self.dwconv(x)
if self.ffn_layernorm is not None:
x = self.ffn_layernorm(x)
x = x + residual
x = self.fc2(x)
x = self.dropout_module(x)
return x
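# One DFormerv2 block: depth-wise-conv position encoding, geometry
# self-attention (decomposed or full), and an FFN, each wrapped in a pre-norm
# residual branch with drop-path and optional layer scale.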
class RGBD_Block(nn.Module):
    def __init__(self, split_or_not: bool, embed_dim: int, num_heads: int, ffn_dim: int, drop_path=0., layerscale=False, layer_init_values=1e-5, init_value=2, heads_range=4):
super().__init__()
self.layerscale = layerscale
self.embed_dim = embed_dim
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=1e-6)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=1e-6)
if split_or_not:
self.Attention = Decomposed_GSA(embed_dim, num_heads)
else:
self.Attention = Full_GSA(embed_dim, num_heads)
self.drop_path = DropPath(drop_path)
# FFN
self.ffn = FeedForwardNetwork(embed_dim, ffn_dim)
self.cnn_pos_encode = DWConv2d(embed_dim, 3, 1, 1)
# the function to generate the geometry prior for the current block
self.Geo = GeoPriorGen(embed_dim, num_heads, init_value, heads_range)
if layerscale:
self.gamma_1 = nn.Parameter(layer_init_values * torch.ones(1, 1, 1, embed_dim),requires_grad=True)
self.gamma_2 = nn.Parameter(layer_init_values * torch.ones(1, 1, 1, embed_dim),requires_grad=True)
def forward(
self,
x: torch.Tensor,
x_e: torch.Tensor,
split_or_not=False
):
x = x + self.cnn_pos_encode(x)
b, h, w, d = x.size()
geo_prior = self.Geo((h, w), x_e, split_or_not=split_or_not)
if self.layerscale:
x = x + self.drop_path(self.gamma_1 * self.Attention(self.layer_norm1(x), geo_prior, split_or_not))
x = x + self.drop_path(self.gamma_2 * self.ffn(self.layer_norm2(x)))
else:
x = x + self.drop_path(self.Attention(self.layer_norm1(x), geo_prior, split_or_not))
x = x + self.drop_path(self.ffn(self.layer_norm2(x)))
return x
class BasicLayer(nn.Module):
"""
A basic RGB-D layer in DFormerv2.
"""
def __init__(self, embed_dim, out_dim, depth, num_heads,
init_value: float, heads_range: float,
ffn_dim=96., drop_path=0., norm_layer=nn.LayerNorm, split_or_not=False,
downsample: PatchMerging=None, use_checkpoint=False,
layerscale=False, layer_init_values=1e-5):
super().__init__()
self.embed_dim = embed_dim
self.depth = depth
self.use_checkpoint = use_checkpoint
self.split_or_not = split_or_not
# build blocks
self.blocks = nn.ModuleList([
RGBD_Block(split_or_not, embed_dim, num_heads, ffn_dim,
drop_path[i] if isinstance(drop_path, list) else drop_path, layerscale, layer_init_values, init_value=init_value, heads_range=heads_range)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=embed_dim, out_dim=out_dim, norm_layer=norm_layer)
else:
self.downsample = None
    def forward(self, x, x_e):
        for blk in self.blocks:
            if self.use_checkpoint:
                # reentrant activation checkpointing does not accept keyword
                # arguments, so pass everything positionally
                x = checkpoint.checkpoint(blk, x, x_e, self.split_or_not)
            else:
                x = blk(x, x_e, split_or_not=self.split_or_not)
if self.downsample is not None:
x_down = self.downsample(x)
return x, x_down
else:
return x, x
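# The DFormerv2 backbone: a stride-4 conv stem followed by four stages whose
# outputs have strides 4, 8, 16, and 32 and channel widths `embed_dims`. The
# first three stages use decomposed GSA, the last full GSA; the depth map is
# carried alongside the RGB features and resized inside each block to build
# the geometry prior.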
class dformerv2(nn.Module):
def __init__(self, out_indices=(0, 1, 2, 3),
embed_dims=[64, 128, 256, 512], depths=[2, 2, 8, 2], num_heads=[4, 4, 8, 16],
init_values=[2, 2, 2, 2], heads_ranges=[4, 4, 6, 6], mlp_ratios=[4, 4, 3, 3], drop_path_rate=0.1, norm_layer=nn.LayerNorm,
patch_norm=True, use_checkpoint=False, projection=1024, norm_cfg = None,
layerscales=[False, False, False, False], layer_init_values=1e-6, norm_eval=True):
super().__init__()
self.out_indices = out_indices
self.num_layers = len(depths)
self.embed_dim = embed_dims[0]
self.patch_norm = patch_norm
self.num_features = embed_dims[-1]
self.mlp_ratios = mlp_ratios
self.norm_eval = norm_eval
# patch embedding
self.patch_embed = PatchEmbed(in_chans=3, embed_dim=embed_dims[0],
norm_layer=norm_layer if self.patch_norm else None)
# drop path rate
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
embed_dim=embed_dims[i_layer],
out_dim=embed_dims[i_layer+1] if (i_layer < self.num_layers - 1) else None,
depth=depths[i_layer],
num_heads=num_heads[i_layer],
init_value=init_values[i_layer],
heads_range=heads_ranges[i_layer],
ffn_dim=int(mlp_ratios[i_layer]*embed_dims[i_layer]),
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
split_or_not=(i_layer!=3),
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
layerscale=layerscales[i_layer],
layer_init_values=layer_init_values
)
self.layers.append(layer)
self.extra_norms = nn.ModuleList()
for i in range(3):
self.extra_norms.append(nn.LayerNorm(embed_dims[i+1]))
self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # elementwise_affine=False leaves weight and bias as None
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
        if isinstance(pretrained, str):
            self.apply(_init_weights)
            _state_dict = torch.load(pretrained, map_location='cpu')
            if 'model' in _state_dict.keys():
                _state_dict = _state_dict['model']
            if 'state_dict' in _state_dict.keys():
                _state_dict = _state_dict['state_dict']
            # strip a possible 'backbone.' prefix left by downstream wrappers
            state_dict = OrderedDict()
            for k, v in _state_dict.items():
                if k.startswith('backbone.'):
                    state_dict[k[len('backbone.'):]] = v
                else:
                    state_dict[k] = v
            print('load ' + pretrained)
            load_state_dict(self, state_dict, strict=False)
elif pretrained is None:
self.apply(_init_weights)
else:
raise TypeError('pretrained must be a str or None')
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
    def forward(self, x, x_e):
        '''
        x: RGB input, (b 3 h w)
        x_e: depth input; only the first channel is used, giving (b 1 h w)
        '''
        x = self.patch_embed(x)
        x_e = x_e[:, 0, :, :].unsqueeze(1)
        outs = []
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, x = layer(x, x_e)
            if i in self.out_indices:
                if i != 0:
                    x_out = self.extra_norms[i - 1](x_out)
                out = x_out.permute(0, 3, 1, 2).contiguous()
                outs.append(out)
        return tuple(outs)
    def train(self, mode=True):
        """Convert the model into training mode while keeping normalization
        layers frozen."""
        super().train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # _BatchNorm also matches nn.SyncBatchNorm, which this model
                # uses in its stem and downsampling layers
                if isinstance(m, _BatchNorm):
                    m.eval()
def DFormerv2_S(pretrained=False, **kwargs):
    # `pretrained` is accepted for API compatibility; weights are loaded
    # explicitly via model.init_weights(path)
    model = dformerv2(embed_dims=[64, 128, 256, 512], depths=[3, 4, 18, 4], num_heads=[4, 4, 8, 16],
                      heads_ranges=[4, 4, 6, 6], **kwargs)
    return model
def DFormerv2_B(pretrained=False, **kwargs):
model = dformerv2(embed_dims=[80, 160, 320, 512], depths=[4, 8, 25, 8], num_heads=[5, 5, 10, 16],
heads_ranges=[5, 5, 6, 6],
layerscales=[False, False, True, True],
layer_init_values=1e-6, **kwargs)
return model
def DFormerv2_L(pretrained=False, **kwargs):
model = dformerv2(embed_dims=[112, 224, 448, 640], depths=[4, 8, 25, 8], num_heads=[7, 7, 14, 20],
heads_ranges=[6, 6, 6, 6],
layerscales=[False, False, True, True],
layer_init_values=1e-6, **kwargs)
return model
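

if __name__ == '__main__':
    # Minimal smoke test (illustrative, not part of the original release):
    # build the small variant and run one forward pass on random RGB + depth
    # tensors. The SyncBatchNorm layers need a distributed process group for
    # training, but in eval mode recent PyTorch versions fall back to plain
    # batch-norm statistics, so this runs on a single CPU.
    model = DFormerv2_S().eval()
    rgb = torch.randn(1, 3, 480, 640)    # (b, 3, h, w) RGB image
    depth = torch.randn(1, 1, 480, 640)  # (b, 1, h, w) depth map
    with torch.no_grad():
        feats = model(rgb, depth)
    # Expected: four feature maps at strides 4/8/16/32 with the stage widths,
    # i.e. (1, 64, 120, 160), (1, 128, 60, 80), (1, 256, 30, 40), (1, 512, 15, 20)
    for f in feats:
        print(tuple(f.shape))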