|
import torch
|
|
import torch.nn as nn
|
|
from models.encoder import SparseConvNeXtLayerNorm, _get_active_ex_or_ii
|
|
from typing import Optional, Sequence, Tuple, Union, List
|
|
import numpy as np
|
|
from models.mamba.bi_vision_mamba import Mamba
|
|
from monai.networks.blocks.unetr_block import UnetrUpBlock
|
|
|
|
def build_3d_sincos_position_embedding(grid_size, embed_dim, num_tokens=0, temperature=10000.):
    """Build a fixed (non-trainable) 3D sine-cosine position embedding.

    Args:
        grid_size: either a single number (cubic grid, ``h == w == d``) or an
            iterable ``(h, w, d)`` giving the number of positions per axis.
        embed_dim: embedding width; must be divisible by 6 because each of the
            three axes contributes a sine half and a cosine half.
        num_tokens: 0, or 1 to prepend a single all-zero embedding row
            (e.g. for an extra token placed before the grid tokens).
        temperature: base of the geometric frequency progression.

    Returns:
        ``nn.Parameter`` of shape ``(1, h * w * d + num_tokens, embed_dim)``
        with ``requires_grad=False``.

    Raises:
        AssertionError: if ``embed_dim`` is not divisible by 6 or
            ``num_tokens`` is neither 0 nor 1.
    """
    # A scalar grid size means a cube; anything unpackable is taken as (h, w, d).
    try:
        h, w, d = grid_size
    except TypeError:
        h = w = d = grid_size

    axes = [torch.arange(n, dtype=torch.float32) for n in (h, w, d)]
    # 'ij' is the historical default; passing it explicitly keeps the layout
    # stable and avoids the deprecation warning of the implicit form.
    grid_h, grid_w, grid_d = torch.meshgrid(*axes, indexing='ij')

    assert embed_dim % 6 == 0, 'Embed dimension must be divisible by 6 for 3D sin-cos position embedding'
    pos_dim = embed_dim // 6
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1. / (temperature ** omega)

    # (positions, pos_dim) phase tables, one per axis.
    out_h = torch.outer(grid_h.flatten(), omega)
    out_w = torch.outer(grid_w.flatten(), omega)
    out_d = torch.outer(grid_d.flatten(), omega)
    pos_emb = torch.cat(
        [torch.sin(out_h), torch.cos(out_h),
         torch.sin(out_w), torch.cos(out_w),
         torch.sin(out_d), torch.cos(out_d)],
        dim=1)[None, :, :]

    assert num_tokens == 1 or num_tokens == 0, "Number of tokens must be of 0 or 1"
    if num_tokens == 1:
        pe_token = torch.zeros([1, 1, embed_dim], dtype=torch.float32)
        pos_emb = torch.cat([pe_token, pos_emb], dim=1)

    return nn.Parameter(pos_emb, requires_grad=False)
|
|
|
|
|
|
class MlpChannel(nn.Module):
    """Two-layer feed-forward block applied on the last dimension.

    Projects ``hidden_size -> mlp_dim``, applies GELU, then projects back
    ``mlp_dim -> hidden_size``.
    """

    def __init__(self, hidden_size, mlp_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_features=hidden_size, out_features=mlp_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(in_features=mlp_dim, out_features=hidden_size)

    def forward(self, x):
        """Return ``fc2(GELU(fc1(x)))``."""
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)
|
|
|
|
|
|
class MambaLayer(nn.Module):
    """Pre-norm residual block: a Mamba mixer followed by a channel MLP.

    Both sub-layers are wrapped as ``x + f(LayerNorm(x))``. The MLP hidden
    width is twice ``dim``.
    """

    def __init__(self, dim, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.dim = dim
        self.norm1 = nn.LayerNorm(dim)
        # Options forwarded to the project's Mamba implementation.
        mixer_options = dict(
            d_model=dim,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            bimamba_type="v1",
        )
        self.mamba = Mamba(**mixer_options)
        self.mlp = MlpChannel(hidden_size=dim, mlp_dim=2 * dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Apply the mixer and MLP sub-layers, each with a skip connection."""
        mixed = self.mamba(self.norm1(x)) + x
        return self.mlp(self.norm2(mixed)) + mixed
|
|
|
|
|
|
class MaskedAutoencoderMamba(nn.Module):
    """Stack of ``MambaLayer`` blocks over the tokens of a 3D feature map.

    The input feature map ``(B, C, H, W, D)`` is flattened to ``H * W * D``
    tokens, a fixed 3D sin-cos position embedding is added, and the tokens are
    passed through ``depth`` Mamba blocks. When ``sparse`` is true only the
    tokens flagged as active are processed; the remaining positions are filled
    with a learned mask token afterwards.

    Args:
        img_size: spatial size of the network input (per axis).
        downsample_rato: total downsampling factor before this module
            (parameter name kept as-is for caller compatibility).
        embed_dim: channel width of the tokens; must be divisible by 6.
        depth: number of ``MambaLayer`` blocks.
        norm_layer: accepted for interface compatibility; not used here.
        sparse: enable masked (sparse) encoding.
    """

    def __init__(self, img_size=96, downsample_rato=16, embed_dim=384, depth=8, norm_layer=nn.LayerNorm, sparse=True):
        super().__init__()
        print("mamba sparse: ", sparse)

        self.grid_size = img_size // downsample_rato
        self.num_patches = self.grid_size ** 3
        self.embed_dim = embed_dim
        # Placeholder; filled with the sin-cos table in initialize_weights().
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim),
                                      requires_grad=False)

        self.blocks = nn.ModuleList([
            MambaLayer(dim=embed_dim)
            for _ in range(depth)])

        self.sparse = sparse
        if self.sparse:
            self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.initialize_weights()

    def initialize_weights(self):
        """Fill the position table, the mask token and all Linear/LayerNorm weights."""
        pos_embed = build_3d_sincos_position_embedding(self.grid_size, self.embed_dim)
        self.pos_embed.data.copy_(pos_embed)
        if self.sparse:
            torch.nn.init.normal_(self.mask_token, std=.02)

        # Applied recursively to every sub-module, including those inside the blocks.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform for Linear weights, zero biases, unit LayerNorm scale."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, enc, active_b1fff):
        """Select the active tokens of every sample, preserving their order.

        Despite the name, no randomness is drawn here: the selection is fully
        determined by ``active_b1fff``.

        Args:
            enc: token sequence of shape ``(N, L, D)``.
            active_b1fff: activity mask (non-zero = keep) that flattens from
                dim 2 onwards to ``(N, 1, L)``. Every sample must have the same
                number of active positions so the result can be batched.

        Returns:
            ``(x_masked, mask, ids_restore)`` where ``x_masked`` is
            ``(N, len_keep, D)``, ``mask`` is the int mask reshaped to
            ``(N, L, 1)`` (1 = kept) and ``ids_restore`` is ``(N, L, 1)``, the
            inverse permutation used by :meth:`unmasking`.

        Raises:
            ValueError: if the mask is missing or samples keep a different
                number of tokens.
        """
        if active_b1fff is None:
            raise ValueError("active_b1fff must be provided when sparse=True")

        D = enc.shape[2]
        # as_tensor avoids the copy-construct warning of torch.tensor(tensor)
        # and puts the mask on the same device as the tokens.
        mask = torch.as_tensor(active_b1fff, dtype=torch.int, device=enc.device)
        mask = mask.flatten(2).transpose(1, 2)

        # The kept length is a per-sample quantity. Summing over the whole
        # batch would let masked tokens slip into the kept set for N > 1.
        keep_counts = mask.sum(dim=(1, 2))
        len_keep = int(keep_counts[0])
        if not bool((keep_counts == len_keep).all()):
            raise ValueError("every sample must have the same number of active tokens")

        # Active positions get key 0, masked ones 1. A stable sort keeps the
        # original token order inside each group, which matters for a
        # sequence model.
        noise = 1 - mask
        ids_shuffle = torch.sort(noise, dim=1, stable=True).indices
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(enc, dim=1, index=ids_keep.repeat(1, 1, D))

        return x_masked, mask, ids_restore

    def unmasking(self, x, ids_restore):
        """Scatter kept tokens back to their positions, filling gaps with the mask token."""
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
        x_ = torch.cat([x, mask_tokens], dim=1)
        x = torch.gather(x_, dim=1, index=ids_restore.repeat(1, 1, x.shape[2]))
        return x

    def forward_encoder(self, enc, active_b1fff=None):
        """Run the Mamba stack on a ``(B, C, H, W, D)`` feature map; output has the same shape."""
        B, C, H, W, D = enc.shape
        x = enc.flatten(2).transpose(1, 2)

        x = x + self.pos_embed

        ids_restore = None
        if self.sparse:
            x, _, ids_restore = self.random_masking(x, active_b1fff)

        for blk in self.blocks:
            x = blk(x)

        if self.sparse:
            x = self.unmasking(x, ids_restore)

        x = x.transpose(1, 2).reshape(B, C, H, W, D)
        return x

    def forward(self, imgs, active_b1fff=None):
        """Alias of :meth:`forward_encoder`."""
        return self.forward_encoder(imgs, active_b1fff)
|
|
|
|
|
|
class MedNeXtBlock(nn.Module):
    """MedNeXt-style 3D block: grouped conv -> norm -> 1x1 expand -> GELU -> 1x1 project.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the final 1x1 convolution.
        exp_r: expansion ratio of the hidden 1x1 convolution.
        kernel_size: kernel size of the first (grouped) convolution; padding is
            ``kernel_size // 2``.
        do_res: add the block input to the output. The shapes must be
            compatible for the addition.
        n_groups: number of groups of the first convolution; ``None`` means
            depthwise (``groups=in_channels``).
        sparse: forwarded to the normalization layer and, when true, the block
            output is multiplied by the mask obtained from
            ``_get_active_ex_or_ii``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 exp_r: int = 4,
                 kernel_size: int = 7,
                 do_res: bool = True,
                 n_groups: Optional[int] = None,
                 sparse=False):
        super().__init__()

        self.do_res = do_res
        self.sparse = sparse
        conv = nn.Conv3d

        # Spatial mixing: grouped (depthwise by default) convolution.
        self.conv1 = conv(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            groups=in_channels if n_groups is None else n_groups,
        )

        self.norm = SparseConvNeXtLayerNorm(normalized_shape=in_channels, data_format='channels_first', sparse=sparse)

        # Channel expansion.
        self.conv2 = conv(
            in_channels=in_channels,
            out_channels=exp_r * in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )

        self.act = nn.GELU()

        # Channel projection to the output width.
        self.conv3 = conv(
            in_channels=exp_r * in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )

    def forward(self, x, dummy_tensor=None):
        """Apply the block to ``x``. ``dummy_tensor`` is accepted but unused."""
        x1 = x
        x1 = self.conv1(x1)
        x1 = self.act(self.conv2(self.norm(x1)))
        x1 = self.conv3(x1)
        if self.sparse:
            # NOTE(review): presumably zeroes inactive positions at this
            # resolution; confirm against models.encoder.
            x1 *= _get_active_ex_or_ii(H=x1.shape[2], W=x1.shape[3], D=x1.shape[4], returning_active_ex=True)
        if self.do_res:
            x1 = x + x1
        return x1
|
|
|
|
|
|
class MedNeXtDownBlock(MedNeXtBlock):
    """``MedNeXtBlock`` variant whose first convolution has stride 2.

    The parent block is built without its identity residual. When ``do_res``
    is true, a strided 1x1 convolution of the input is added to the output
    instead.
    """

    def __init__(self, in_channels, out_channels, exp_r=4, kernel_size=7,
                 do_res=False, sparse=False):
        super().__init__(in_channels, out_channels, exp_r, kernel_size,
                         do_res=False, sparse=sparse)

        self.resample_do_res = do_res
        if self.resample_do_res:
            # Shortcut branch that matches the strided main path.
            self.res_conv = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=2)

        # Replace the parent's stride-1 depthwise conv with a stride-2 one.
        self.conv1 = nn.Conv3d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=2,
            padding=kernel_size // 2,
            groups=in_channels,
        )

    def forward(self, x, dummy_tensor=None):
        """Downsample ``x``. ``dummy_tensor`` is accepted but unused."""
        out = super().forward(x)
        if not self.resample_do_res:
            return out
        shortcut = self.res_conv(x)
        return out + shortcut
|
|
|
|
|
|
class UnetResBlock(nn.Module):
    """
    A skip-connection based module that can be used for DynUNet, based on:
    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.
    `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation <https://arxiv.org/abs/1809.10486>`_.

    This variant is fixed to 3D convolutions, ``SparseConvNeXtLayerNorm``
    normalization and LeakyReLU activation.

    Args:
        sparse: forwarded to every ``SparseConvNeXtLayerNorm``.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size (int or one int per spatial dim).
        stride: stride of the first convolution (and of the shortcut).
    """

    def __init__(
        self,
        sparse: bool,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
    ):
        super().__init__()
        # "Same"-style padding; computed per dimension so that sequence
        # kernel sizes (allowed by the annotation) work as well as ints.
        try:
            padding = tuple(k // 2 for k in kernel_size)
        except TypeError:
            padding = kernel_size // 2

        self.conv1 = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding)
        self.conv2 = nn.Conv3d(
            out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
        )
        self.lrelu = nn.LeakyReLU(inplace=True, negative_slope=0.01)
        self.norm1 = SparseConvNeXtLayerNorm(normalized_shape=out_channels, data_format='channels_first', sparse=sparse)
        self.norm2 = SparseConvNeXtLayerNorm(normalized_shape=out_channels, data_format='channels_first', sparse=sparse)

        # A projection shortcut is needed whenever the main path changes the
        # channel count or the spatial resolution.
        self.downsample = in_channels != out_channels
        stride_np = np.atleast_1d(stride)
        if not np.all(stride_np == 1):
            self.downsample = True
        if self.downsample:
            self.conv3 = nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride)
            self.norm3 = SparseConvNeXtLayerNorm(normalized_shape=out_channels, data_format='channels_first', sparse=sparse)

    def forward(self, inp):
        """Return ``lrelu(norm2(conv2(lrelu(norm1(conv1(inp))))) + shortcut(inp))``."""
        residual = inp
        out = self.conv1(inp)
        out = self.norm1(out)
        out = self.lrelu(out)
        out = self.conv2(out)
        out = self.norm2(out)
        # conv3/norm3 exist only when a projection shortcut was created.
        if hasattr(self, "conv3"):
            residual = self.conv3(residual)
        if hasattr(self, "norm3"):
            residual = self.norm3(residual)
        out += residual
        out = self.lrelu(out)
        return out
|
|
|
|
|
|
class MedNeXtUpBlock(MedNeXtBlock):
    """``MedNeXtBlock`` variant whose first convolution is a stride-2 transposed conv.

    The parent block is built without its identity residual. When ``do_res``
    is true, a stride-2 transposed 1x1 convolution of the input is added to
    the output. Both branches are padded by one element at the front of each
    of the last three dimensions.
    """

    def __init__(self, in_channels, out_channels, exp_r=4, kernel_size=3,
                 do_res=True, sparse=False):
        super().__init__(in_channels, out_channels, exp_r, kernel_size,
                         do_res=False, sparse=sparse)

        self.resample_do_res = do_res

        if self.resample_do_res:
            # Shortcut branch that matches the upsampled main path.
            self.res_conv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=1, stride=2)

        # Replace the parent's stride-1 depthwise conv with a transposed one.
        self.conv1 = nn.ConvTranspose3d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=2,
            padding=kernel_size // 2,
            groups=in_channels,
        )

    def forward(self, x, dummy_tensor=None):
        """Upsample ``x``. ``dummy_tensor`` is accepted but unused."""
        pad = torch.nn.functional.pad
        front_pad = (1, 0, 1, 0, 1, 0)

        out = pad(super().forward(x), front_pad)
        if not self.resample_do_res:
            return out

        shortcut = pad(self.res_conv(x), front_pad)
        return out + shortcut
|
|
|
|
|
|
class UnetOutBlock(nn.Module):
    """Output head: a single biased 1x1x1 convolution to ``n_classes`` channels."""

    def __init__(self, in_channels: int, n_classes: int):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels=in_channels,
            out_channels=n_classes,
            kernel_size=1,
            stride=1,
            bias=True,
        )

    def forward(self, inp):
        """Project ``inp`` to per-class channels."""
        projected = self.conv(inp)
        return projected
|
|
|
|
|
|
class Embeddings(nn.Module):
    """Convolutional feature pyramid producing five feature maps.

    A stem convolution is followed by four stages. Each stage widens the
    channels with a 3x3x3 convolution (``expendN``), keeps that tensor as a
    skip feature, then max-pools by 2 and applies a stack of ``MedNeXtBlock``
    (``layerN+1``). Skip features pass through a ``UnetResBlock`` before being
    returned.

    Only indices 1..4 of ``depths``, ``kernels`` and ``exp_r`` are used.
    """

    def __init__(self,
                 in_channel: int = 3,
                 channels: Tuple = (32, 64, 96, 128, 192),
                 depths: Tuple = (1, 1, 3, 1, 1),
                 kernels: Tuple = (3, 3, 3, 3, 3),
                 exp_r: Tuple = (2, 4, 4, 4, 2),
                 sparse=True):
        super(Embeddings, self).__init__()

        def make_stage(level):
            # Stack of residual MedNeXt blocks at a constant channel width.
            width = channels[level]
            return nn.Sequential(*[
                MedNeXtBlock(
                    in_channels=width,
                    out_channels=width,
                    exp_r=exp_r[level],
                    kernel_size=kernels[level],
                    do_res=True,
                    sparse=sparse,
                )
                for _ in range(depths[level])])

        def make_widen(level):
            # 3x3x3 convolution from channels[level] to channels[level + 1].
            return nn.Conv3d(in_channels=channels[level], out_channels=channels[level + 1],
                             kernel_size=3, stride=1, padding=1)

        def make_skip_refiner(level):
            # Residual block applied to a skip feature before it is returned.
            return UnetResBlock(
                in_channels=channels[level],
                out_channels=channels[level],
                kernel_size=3,
                stride=1,
                sparse=sparse,
            )

        # Channel counts of the five returned feature maps.
        self.dim = [channels[1], channels[2], channels[3], channels[4], channels[4]]
        self.stem = nn.Conv3d(in_channels=in_channel, out_channels=channels[0], kernel_size=3, stride=1, padding=1)

        self.layer2 = make_stage(1)
        self.layer3 = make_stage(2)
        self.layer4 = make_stage(3)
        self.layer5 = make_stage(4)

        self.down = nn.MaxPool3d((2, 2, 2))
        self.expend1 = make_widen(0)
        self.expend2 = make_widen(1)
        self.expend3 = make_widen(2)
        self.expend4 = make_widen(3)

        self.encoder1 = make_skip_refiner(1)
        self.encoder2 = make_skip_refiner(2)
        self.encoder3 = make_skip_refiner(3)
        self.encoder4 = make_skip_refiner(4)

    def forward(self, x):
        """Return four refined skip features plus the deepest feature map."""
        x1 = self.expend1(self.stem(x))
        x2 = self.expend2(self.layer2(self.down(x1)))
        x3 = self.expend3(self.layer3(self.down(x2)))
        x4 = self.expend4(self.layer4(self.down(x3)))
        x5 = self.layer5(self.down(x4))

        return self.encoder1(x1), self.encoder2(x2), self.encoder3(x3), self.encoder4(x4), x5
|
|
|
|
|
|
class Encoder(nn.Module):
    """Convolutional pyramid (``Embeddings``) followed by a Mamba stack on the deepest map."""

    def __init__(self,
                 in_channel: int = 1,
                 channels=(32, 64, 128, 192, 384),
                 depths=(1, 2, 2, 2, 1),
                 kernels=(3, 3, 3, 3, 3),
                 exp_r=(2, 2, 4, 4, 4),
                 img_size=96,
                 depth=4,
                 norm_layer=nn.LayerNorm,
                 sparse=False):
        super(Encoder, self).__init__()
        # Channel counts of the five feature maps returned by forward().
        self.dim = [channels[1], channels[2], channels[3], channels[4], channels[4]]

        pyramid_config = dict(
            in_channel=in_channel,
            channels=channels,
            depths=depths,
            kernels=kernels,
            exp_r=exp_r,
            sparse=sparse,
        )
        self.embeddings = Embeddings(**pyramid_config)

        mamba_config = dict(
            img_size=img_size,
            downsample_rato=self.get_downsample_ratio(),
            embed_dim=channels[-1],
            depth=depth,
            norm_layer=norm_layer,
            sparse=sparse,
        )
        self.mae = MaskedAutoencoderMamba(**mamba_config)

    def get_downsample_ratio(self) -> int:
        """
        This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).

        :return: the TOTAL downsample ratio of the ConvNet.
        E.g., for a ResNet-50, this should return 32.
        """
        return 16

    def get_feature_map_channels(self) -> List[int]:
        """
        This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).

        :return: a list of the number of channels of each feature map.
        E.g., for a ResNet-50, this should return [256, 512, 1024, 2048].
        """
        return self.dim

    def forward(self, x, active_b1fff=None):
        """Return the four skip features and the Mamba-processed deepest feature."""
        *skips, deepest = self.embeddings(x)
        encoded = self.mae(deepest, active_b1fff)
        return (*skips, encoded)
|
|
|
|
|
|
class Decoder(nn.Module):
    """U-Net style decoder: four ``UnetrUpBlock`` stages, a residual block and a 1x1x1 head."""

    # NOTE(review): the default channels[3] is 196 while Encoder/build_hybird
    # use 192 - looks like a typo, but the default is left untouched here.
    def __init__(self,
                 n_classes: int = 3,
                 channels: Tuple = (32, 64, 128, 196, 384),
                 norm_name = "instance",
                 res_block: bool = True):
        super(Decoder, self).__init__()

        def make_up(level_in, level_out):
            # Upsampling stage from channels[level_in] to channels[level_out].
            return UnetrUpBlock(
                spatial_dims=3,
                in_channels=channels[level_in],
                out_channels=channels[level_out],
                kernel_size=3,
                upsample_kernel_size=2,
                norm_name=norm_name,
                res_block=res_block,
            )

        self.decoder5 = make_up(4, 4)
        self.decoder4 = make_up(4, 3)
        self.decoder3 = make_up(3, 2)
        self.decoder2 = make_up(2, 1)
        self.decoder1 = UnetResBlock(
            in_channels=channels[1],
            out_channels=channels[0],
            kernel_size=3,
            stride=1,
            sparse=False,
        )
        self.out = UnetOutBlock(in_channels=channels[0], n_classes=n_classes)

    def forward(self, x1, x2, x3, x4, x5):
        """Decode from the deepest feature ``x5`` using skips ``x4`` .. ``x1``."""
        feat = self.decoder5(x5, x4)
        feat = self.decoder4(feat, x3)
        feat = self.decoder3(feat, x2)
        feat = self.decoder2(feat, x1)
        feat = self.decoder1(feat)
        return self.out(feat)
|
|
|
|
|
|
class Hybird(nn.Module):
    """Full segmentation network: ``Embeddings`` -> dense Mamba stack -> ``Decoder``.

    All sub-modules are built with ``sparse=False``.
    """

    def __init__(self,
                 in_channel: int = 3,
                 n_classes: int = 3,
                 channels: Tuple = (32, 64, 96, 128, 192),
                 depths: Tuple = (1, 1, 3, 3, 1),
                 kernels: Tuple = (3, 3, 3, 3, 3),
                 exp_r: Tuple = (2, 4, 4, 4, 2),
                 img_size=96,
                 depth=3,
                 norm_layer=nn.LayerNorm, ):
        super().__init__()

        pyramid_config = dict(
            in_channel=in_channel,
            channels=channels,
            depths=depths,
            kernels=kernels,
            exp_r=exp_r,
            sparse=False,
        )
        self.embeddings = Embeddings(**pyramid_config)

        mamba_config = dict(
            img_size=img_size,
            downsample_rato=16,
            embed_dim=channels[-1],
            depth=depth,
            norm_layer=norm_layer,
            sparse=False,
        )
        self.mae = MaskedAutoencoderMamba(**mamba_config)

        self.decoder = Decoder(n_classes=n_classes, channels=channels)

    def forward(self, x):
        """Return the decoder output for input volume ``x``."""
        *skips, deepest = self.embeddings(x)
        deepest = self.mae(deepest, None)
        return self.decoder(*skips, deepest)
|
|
|
|
|
|
def build_hybird(in_channel=1, n_classes=14, img_size=96):
    """Create a ``Hybird`` network with this file's fixed architecture preset.

    Only the input channels, the number of classes and the input size are
    configurable; widths, depths, kernels and expansion ratios are fixed.
    """
    preset = dict(
        channels=(32, 64, 128, 192, 384),
        depths=(1, 2, 2, 2, 1),
        kernels=(3, 3, 3, 3, 3),
        exp_r=(2, 2, 4, 4, 4),
        depth=4,
    )
    return Hybird(in_channel=in_channel,
                  n_classes=n_classes,
                  img_size=img_size,
                  **preset)
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: push one random single-channel 96^3 volume through the
    # default network and print the output shape.
    x = torch.rand(1, 1, 96, 96, 96)
    network = build_hybird()
    output = network(x)
    print(output.shape)
|
|
|
|
|
|
|