#!/usr/bin/env python3
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The 3D NEST transformer based segmentation model
MASI Lab, Vanderbilty University
Authors: Xin Yu, Yinchi Zhou, Yucheng Tang, Bennett Landman
The NEST code is partly from
Nested Hierarchical Transformer: Towards Accurate, Data-Efficient and
Interpretable Visual Understanding
https://arxiv.org/pdf/2105.12723.pdf
"""
from typing import Sequence, Tuple, Union
import torch
import torch.nn as nn
from monai.networks.blocks import Convolution
from monai.networks.blocks.dynunet_block import UnetOutBlock
from scripts.networks.nest_transformer_3D import NestTransformer3D
from scripts.networks.unest_block import UNesTBlock, UNesTConvBlock, UNestUpBlock
class UNesT(nn.Module):
"""
UNesT model implementation
"""
def __init__(
self,
in_channels: int,
out_channels: int,
img_size: Sequence[int] = (96, 96, 96),
feature_size: int = 16,
patch_size: int = 2,
depths: Sequence[int] = (2, 2, 2, 2),
num_heads: Sequence[int] = (3, 6, 12, 24),
embed_dim: Sequence[int] = (128, 256, 512),
window_size: Sequence[int] = (7, 7, 7),
norm_name: Union[Tuple, str] = "instance",
conv_block: bool = False,
res_block: bool = True,
dropout_rate: float = 0.0,
) -> None:
"""
Args:
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            img_size: dimension of input image.
            feature_size: dimension of network feature size.
            patch_size: size of the patches used by the patch embedding.
            depths: number of transformer blocks at each hierarchy level.
            num_heads: number of attention heads at each hierarchy level.
            embed_dim: embedding dimension at each hierarchy level.
            window_size: local window size for block attention (not currently passed to the transformer).
            norm_name: feature normalization type and arguments.
            conv_block: bool argument to determine if convolutional block is used.
            res_block: bool argument to determine if residual block is used.
            dropout_rate: fraction of the input units to drop.
Examples:
            # for single-channel input, 4-channel output with image size of (96, 96, 96), feature size of 32 and batch norm
            >>> net = UNesT(in_channels=1, out_channels=4, img_size=(96, 96, 96), feature_size=32, norm_name='batch')

            # for 4-channel input, 3-channel output with image size of (128, 128, 128) and instance norm
            >>> net = UNesT(in_channels=4, out_channels=3, img_size=(128, 128, 128), norm_name='instance')
"""
super().__init__()
if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")
self.embed_dim = embed_dim
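        # Hierarchical nested transformer (NesT) encoder: in forward() it takes the input volume
        # and returns the deepest feature map plus a list of intermediate hidden states that feed
        # the convolutional skip paths defined below.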
self.nestViT = NestTransformer3D(
            img_size=img_size[0],
            in_chans=in_channels,
patch_size=patch_size,
num_levels=3,
embed_dims=embed_dim,
num_heads=num_heads,
depths=depths,
num_classes=1000,
mlp_ratio=4.0,
qkv_bias=True,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.5,
norm_layer=None,
act_layer=None,
pad_type="",
weight_init="",
global_pool="avg",
)
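        # Convolutional encoder branches: encoder1 works on the raw input volume, while
        # encoder2-encoder4 project transformer hidden states into skip-connection features.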
self.encoder1 = UNesTConvBlock(
spatial_dims=3,
            in_channels=in_channels,
out_channels=feature_size * 2,
kernel_size=3,
stride=1,
norm_name=norm_name,
res_block=res_block,
)
self.encoder2 = UNestUpBlock(
spatial_dims=3,
in_channels=self.embed_dim[0],
out_channels=feature_size * 4,
num_layer=1,
kernel_size=3,
stride=1,
upsample_kernel_size=2,
norm_name=norm_name,
conv_block=False,
res_block=False,
)
self.encoder3 = UNesTConvBlock(
spatial_dims=3,
in_channels=self.embed_dim[0],
out_channels=8 * feature_size,
kernel_size=3,
stride=1,
norm_name=norm_name,
res_block=res_block,
)
self.encoder4 = UNesTConvBlock(
spatial_dims=3,
in_channels=self.embed_dim[1],
out_channels=16 * feature_size,
kernel_size=3,
stride=1,
norm_name=norm_name,
res_block=res_block,
)
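        # Decoder blocks: each upsamples its input and fuses it with the matching skip feature.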
self.decoder5 = UNesTBlock(
spatial_dims=3,
in_channels=2 * self.embed_dim[2],
out_channels=feature_size * 32,
stride=1,
kernel_size=3,
upsample_kernel_size=2,
norm_name=norm_name,
res_block=res_block,
)
self.decoder4 = UNesTBlock(
spatial_dims=3,
in_channels=self.embed_dim[2],
out_channels=feature_size * 16,
stride=1,
kernel_size=3,
upsample_kernel_size=2,
norm_name=norm_name,
res_block=res_block,
)
self.decoder3 = UNesTBlock(
spatial_dims=3,
in_channels=feature_size * 16,
out_channels=feature_size * 8,
stride=1,
kernel_size=3,
upsample_kernel_size=2,
norm_name=norm_name,
res_block=res_block,
)
self.decoder2 = UNesTBlock(
spatial_dims=3,
in_channels=feature_size * 8,
out_channels=feature_size * 4,
stride=1,
kernel_size=3,
upsample_kernel_size=2,
norm_name=norm_name,
res_block=res_block,
)
self.decoder1 = UNesTBlock(
spatial_dims=3,
in_channels=feature_size * 4,
out_channels=feature_size * 2,
stride=1,
kernel_size=3,
upsample_kernel_size=2,
norm_name=norm_name,
res_block=res_block,
)
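        # Bottleneck: a strided convolution that further downsamples the deepest transformer feature.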
self.encoder10 = Convolution(
spatial_dims=3,
in_channels=32 * feature_size,
out_channels=64 * feature_size,
strides=2,
adn_ordering="ADN",
dropout=0.0,
)
self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels) # type: ignore
    def proj_feat(self, x, hidden_size, feat_size):
        """Reshape a flattened token sequence (B, N, hidden_size) into a (B, hidden_size, D, H, W) feature map."""
x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
x = x.permute(0, 4, 1, 2, 3).contiguous()
return x
    def load_from(self, weights):
        """Copy pretrained transformer weights from a UNETR-style checkpoint.

        Note: this helper references ``self.vit`` and ``module.transformer.*`` state-dict keys,
        which this class does not define; it appears to be carried over from UNETR and is not
        used by the NesT encoder above.
        """
with torch.no_grad():
# copy weights from patch embedding
for i in weights["state_dict"]:
print(i)
self.vit.patch_embedding.position_embeddings.copy_(
weights["state_dict"]["module.transformer.patch_embedding.position_embeddings_3d"]
)
self.vit.patch_embedding.cls_token.copy_(
weights["state_dict"]["module.transformer.patch_embedding.cls_token"]
)
self.vit.patch_embedding.patch_embeddings[1].weight.copy_(
weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings_3d.1.weight"]
)
self.vit.patch_embedding.patch_embeddings[1].bias.copy_(
weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings_3d.1.bias"]
)
# copy weights from encoding blocks (default: num of blocks: 12)
for bname, block in self.vit.blocks.named_children():
print(block)
block.loadFrom(weights, n_block=bname)
# last norm layer of transformer
self.vit.norm.weight.copy_(weights["state_dict"]["module.transformer.norm.weight"])
self.vit.norm.bias.copy_(weights["state_dict"]["module.transformer.norm.bias"])
    def forward(self, x_in):
        # Shape comments assume the default configuration with a (B, 1, 96, 96, 96) input.
        x, hidden_states_out = self.nestViT(x_in)
        enc0 = self.encoder1(x_in)  # (B, 32, 96, 96, 96)
        x1 = hidden_states_out[0]  # (B, 128, 24, 24, 24)
        enc1 = self.encoder2(x1)  # (B, 64, 48, 48, 48)
        x2 = hidden_states_out[1]  # (B, 128, 24, 24, 24)
        enc2 = self.encoder3(x2)  # (B, 128, 24, 24, 24)
        x3 = hidden_states_out[2]  # (B, 256, 12, 12, 12)
        enc3 = self.encoder4(x3)  # (B, 256, 12, 12, 12)
        x4 = hidden_states_out[3]
        enc4 = x4  # (B, 512, 6, 6, 6)
        dec4 = x  # (B, 512, 6, 6, 6)
        dec4 = self.encoder10(dec4)  # (B, 1024, 3, 3, 3)
        dec3 = self.decoder5(dec4, enc4)  # (B, 512, 6, 6, 6)
        dec2 = self.decoder4(dec3, enc3)  # (B, 256, 12, 12, 12)
        dec1 = self.decoder3(dec2, enc2)  # (B, 128, 24, 24, 24)
        dec0 = self.decoder2(dec1, enc1)  # (B, 64, 48, 48, 48)
        out = self.decoder1(dec0, enc0)  # (B, 32, 96, 96, 96)
logits = self.out(out)
return logits
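

# Minimal usage sketch (not part of the original bundle code). It assumes the default
# 96^3 configuration, a single-channel input, and that this module is imported with the
# bundle's `scripts` package on the path; the output channel count below is illustrative.
if __name__ == "__main__":
    net = UNesT(in_channels=1, out_channels=133)  # 133 labels is an assumption for the whole-brain task
    net.eval()
    with torch.no_grad():
        dummy = torch.zeros(1, 1, 96, 96, 96)  # (B, C, D, H, W)
        logits = net(dummy)
    print(logits.shape)  # expected: torch.Size([1, 133, 96, 96, 96])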