import os

# Work around duplicate OpenMP runtimes (common with conda/MKL installs) and pin the GPU.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import nibabel as nib
from scipy import ndimage
from skimage import morphology
from medpy import metric

from huggingface_hub import PyTorchModelHubMixin, hf_hub_download

class ConvBlock(nn.Module):
    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
        super(ConvBlock, self).__init__()

        ops = []
        for i in range(n_stages):
            # The first stage maps n_filters_in -> n_filters_out; later stages keep the width.
            input_channel = n_filters_in if i == 0 else n_filters_out

            ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            elif normalization != 'none':
                raise ValueError('Unknown normalization: {}'.format(normalization))
            ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x
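
# Usage sketch (illustrative only; the tensor shapes below are assumptions, not part of
# the original pipeline). ConvBlock preserves spatial size and only changes channel width:
#
#   block = ConvBlock(n_stages=2, n_filters_in=1, n_filters_out=16, normalization='instancenorm')
#   x = torch.zeros(1, 1, 64, 64, 32)   # (N, C, D, H, W)
#   y = block(x)                        # -> (1, 16, 64, 64, 32)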


class DownsamplingConvBlock(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(DownsamplingConvBlock, self).__init__()

        ops = []
        # kernel_size equals stride, so each call divides every spatial dimension by `stride`.
        ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=stride, stride=stride, padding=0))
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            raise ValueError('Unknown normalization: {}'.format(normalization))
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x
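
# Shape sketch (assumed sizes, for illustration): with the default stride=2,
#
#   down = DownsamplingConvBlock(16, 32, normalization='instancenorm')
#   down(torch.zeros(1, 16, 64, 64, 32)).shape   # -> (1, 32, 32, 32, 16)
#
# UpsamplingDeconvBlock below is its mirror image and multiplies each spatial
# dimension by `stride` via a transposed convolution.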


class UpsamplingDeconvBlock(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(UpsamplingDeconvBlock, self).__init__()

        ops = []
        # kernel_size equals stride, so each call multiplies every spatial dimension by `stride`.
        ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, kernel_size=stride, stride=stride, padding=0))
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            raise ValueError('Unknown normalization: {}'.format(normalization))
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class Upsampling(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(Upsampling, self).__init__()

        ops = []
        ops.append(nn.Upsample(scale_factor=stride, mode='trilinear', align_corners=False))
        ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1))
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            raise ValueError('Unknown normalization: {}'.format(normalization))
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class ConnectNet(nn.Module):
    def __init__(self, in_channels, out_channels, input_size):
        super(ConnectNet, self).__init__()
        # Note: `input_size` is kept for interface compatibility but is not used here.
        self.encoder = nn.Sequential(
            nn.Conv3d(in_channels, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2),
            nn.Conv3d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2)
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(64, 128, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose3d(128, out_channels, kernel_size=2, stride=2),
            nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded


class VNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
        super(VNet, self).__init__()
        self.has_dropout = has_dropout

        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)

        self.dropout = nn.Dropout3d(p=0.5, inplace=False)

        self.__init_weight()

    def encoder(self, input):
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)
        if self.has_dropout:
            x5 = self.dropout(x5)

        res = [x1, x2, x3, x4, x5]
        return res

    def decoder(self, features):
        x1, x2, x3, x4, x5 = features

        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1
        x9 = self.block_nine(x8_up)
        if self.has_dropout:
            x9 = self.dropout(x9)
        out = self.out_conv(x9)
        return out

    def forward(self, input, turnoff_drop=False):
        if turnoff_drop:
            has_dropout = self.has_dropout
            self.has_dropout = False
        features = self.encoder(input)
        out = self.decoder(features)
        if turnoff_drop:
            self.has_dropout = has_dropout
        return out

    def __init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
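
# Usage sketch (assumed shapes; for illustration only). Patch extents must be divisible
# by 16 because the encoder downsamples four times by a factor of 2:
#
#   net = VNet(n_channels=1, n_classes=2, normalization='instancenorm', has_dropout=True)
#   logits = net(torch.zeros(1, 1, 112, 112, 80))   # -> (1, 2, 112, 112, 80)
#   probs = F.softmax(logits, dim=1)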


class VNet_roi(nn.Module):
    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
        super(VNet_roi, self).__init__()
        self.has_dropout = has_dropout

        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)

        self.dropout = nn.Dropout3d(p=0.5, inplace=False)

    def encoder(self, input):
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)

        if self.has_dropout:
            x5 = self.dropout(x5)

        res = [x1, x2, x3, x4, x5]
        return res

    def decoder(self, features):
        x1, x2, x3, x4, x5 = features

        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1
        x9 = self.block_nine(x8_up)

        if self.has_dropout:
            x9 = self.dropout(x9)
        out = self.out_conv(x9)
        return out

    def forward(self, input, turnoff_drop=False):
        if turnoff_drop:
            has_dropout = self.has_dropout
            self.has_dropout = False
        features = self.encoder(input)
        out = self.decoder(features)
        if turnoff_drop:
            self.has_dropout = has_dropout
        return out


class ResVNet(nn.Module):
    def __init__(self, n_channels=1, n_classes=2, n_filters=16, normalization='instancenorm', has_dropout=False):
        super(ResVNet, self).__init__()
        self.resencoder = resnet34()
        self.has_dropout = has_dropout

        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)

        if has_dropout:
            self.dropout = nn.Dropout3d(p=0.5)
        self.branchs = nn.ModuleList()
        for i in range(1):
            if has_dropout:
                seq = nn.Sequential(
                    ConvBlock(1, n_filters, n_filters, normalization=normalization),
                    nn.Dropout3d(p=0.5),
                    nn.Conv3d(n_filters, n_classes, 1, padding=0)
                )
            else:
                seq = nn.Sequential(
                    ConvBlock(1, n_filters, n_filters, normalization=normalization),
                    nn.Conv3d(n_filters, n_classes, 1, padding=0)
                )
            self.branchs.append(seq)

    def encoder(self, input):
        # Plain convolutional encoder; kept for completeness, but forward() uses the
        # ResNet encoder (`self.resencoder`) instead.
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)

        if self.has_dropout:
            x5 = self.dropout(x5)

        res = [x1, x2, x3, x4, x5]
        return res

    def decoder(self, features):
        x1, x2, x3, x4, x5 = features

        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1

        x9 = self.block_nine(x8_up)
        out = self.out_conv(x9)
        return out

    def forward(self, input, turnoff_drop=False):
        if turnoff_drop:
            has_dropout = self.has_dropout
            self.has_dropout = False
        features = self.resencoder(input)
        out = self.decoder(features)
        if turnoff_drop:
            self.has_dropout = has_dropout
        return out


__all__ = ['ResNet', 'resnet34']


def conv3x3(in_planes, out_planes, stride=1):
    """3x3x3 convolution with padding"""
    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    return nn.Sequential(
        conv3x3(in_planes, out_planes, stride),
        nn.InstanceNorm3d(out_planes),
        nn.ReLU()
    )


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(BasicBlock, self).__init__()
        # `dilation` is accepted for interface compatibility with Bottleneck but unused here.
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.InstanceNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.InstanceNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(Bottleneck, self).__init__()
        width = int(planes * (base_width / 64.)) * groups
        self.conv1 = nn.Conv3d(inplanes, width, kernel_size=1, bias=False)
        self.bn1 = nn.InstanceNorm3d(width)
        self.conv2 = nn.Conv3d(width, width, kernel_size=3, stride=stride, dilation=dilation,
                               padding=dilation, groups=groups, bias=False)
        self.bn2 = nn.InstanceNorm3d(width)
        self.conv3 = nn.Conv3d(width, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.InstanceNorm3d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, in_channel=1, width=1,
                 groups=1, width_per_group=64,
                 mid_dim=1024, low_dim=128,
                 avg_down=False, deep_stem=False,
                 head_type='mlp_head', layer4_dilation=1):
        super(ResNet, self).__init__()
        self.avg_down = avg_down
        self.inplanes = 16 * width
        self.base = int(16 * width)
        self.groups = groups
        self.base_width = width_per_group

        # mid_dim is recomputed from the block expansion; the argument value is unused,
        # as are low_dim and head_type (this class is used as an encoder only).
        mid_dim = self.base * 8 * block.expansion

        if deep_stem:
            self.conv1 = nn.Sequential(
                conv3x3_bn_relu(in_channel, 32, stride=2),
                conv3x3_bn_relu(32, 32, stride=1),
                conv3x3(32, 64, stride=1)
            )
        else:
            self.conv1 = nn.Conv3d(in_channel, self.inplanes, kernel_size=7, stride=1, padding=3, bias=False)

        self.bn1 = nn.InstanceNorm3d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, self.base * 2, layers[0], stride=2)
        self.layer2 = self._make_layer(block, self.base * 4, layers[1], stride=2)
        self.layer3 = self._make_layer(block, self.base * 8, layers[2], stride=2)
        if layer4_dilation == 1:
            self.layer4 = self._make_layer(block, self.base * 16, layers[3], stride=2)
        elif layer4_dilation == 2:
            self.layer4 = self._make_layer(block, self.base * 16, layers[3], stride=1, dilation=2)
        else:
            raise NotImplementedError
        self.avgpool = nn.AvgPool3d(7, stride=1)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if self.avg_down:
                downsample = nn.Sequential(
                    nn.AvgPool3d(kernel_size=stride, stride=stride),
                    nn.Conv3d(self.inplanes, planes * block.expansion,
                              kernel_size=1, stride=1, bias=False),
                    nn.InstanceNorm3d(planes * block.expansion),
                )
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(self.inplanes, planes * block.expansion,
                              kernel_size=1, stride=stride, bias=False),
                    nn.InstanceNorm3d(planes * block.expansion),
                )

        layers = [block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, dilation)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        # Encoder-only forward: maxpool/avgpool are defined above but not used here, so
        # each feature level keeps the resolution the V-Net style decoder expects.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        c2 = self.layer1(x)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        return [x, c2, c3, c4, c5]


def resnet34(**kwargs):
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


def label_rescale(image_label, w_ori, h_ori, z_ori, flag):
    w_ori, h_ori, z_ori = int(w_ori), int(h_ori), int(z_ori)

    if flag == 'trilinear':
        teeth_ids = np.unique(image_label)
        image_label_ori = np.zeros((w_ori, h_ori, z_ori))

        image_label = torch.from_numpy(image_label).cuda(0)

        # Interpolate each label's binary mask separately and re-threshold at 0.5,
        # so distinct labels are never blended together.
        for label_id in range(len(teeth_ids)):
            image_label_bn = (image_label == teeth_ids[label_id]).float()
            image_label_bn = image_label_bn[None, None, :, :, :]
            image_label_bn = torch.nn.functional.interpolate(image_label_bn, size=(w_ori, h_ori, z_ori),
                                                             mode='trilinear', align_corners=False)
            image_label_bn = image_label_bn[0, 0, :, :, :]
            image_label_bn = image_label_bn.cpu().data.numpy()
            image_label_ori[image_label_bn > 0.5] = teeth_ids[label_id]
        image_label = image_label_ori

    if flag == 'nearest':
        image_label = torch.from_numpy(image_label).cuda(0)
        image_label = image_label[None, None, :, :, :].float()
        image_label = torch.nn.functional.interpolate(image_label, size=(w_ori, h_ori, z_ori), mode='nearest')
        image_label = image_label[0, 0, :, :, :].cpu().data.numpy()
    return image_label
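
# Usage sketch (assumed array sizes; illustrative only):
#
#   small = np.zeros((56, 56, 40)); small[20:30, 20:30, 10:20] = 3
#   full = label_rescale(small, 112, 112, 80, 'trilinear')   # label-preserving upscaling
#
# 'trilinear' is the label-safe choice because direct trilinear interpolation of integer
# labels would create spurious intermediate values; 'nearest' is the cheaper alternative.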


def img_crop(image_bbox):
    if image_bbox.sum() > 0:
        # Tight bounding box of the foreground, padded by a small safety margin.
        x_min = np.nonzero(image_bbox)[0].min() - 8
        x_max = np.nonzero(image_bbox)[0].max() + 8

        y_min = np.nonzero(image_bbox)[1].min() - 16
        y_max = np.nonzero(image_bbox)[1].max() + 16

        z_min = np.nonzero(image_bbox)[2].min() - 16
        z_max = np.nonzero(image_bbox)[2].max() + 16

        # Clamp to the volume bounds.
        if x_min < 0:
            x_min = 0
        if y_min < 0:
            y_min = 0
        if z_min < 0:
            z_min = 0
        if x_max > image_bbox.shape[0]:
            x_max = image_bbox.shape[0]
        if y_max > image_bbox.shape[1]:
            y_max = image_bbox.shape[1]
        if z_max > image_bbox.shape[2]:
            z_max = image_bbox.shape[2]

        # Trim so every extent is a multiple of 16 (needed by the network's
        # four 2x downsampling stages).
        if (x_max - x_min) % 16 != 0:
            x_max -= (x_max - x_min) % 16
        if (y_max - y_min) % 16 != 0:
            y_max -= (y_max - y_min) % 16
        if (z_max - z_min) % 16 != 0:
            z_max -= (z_max - z_min) % 16

    if image_bbox.sum() == 0:
        # Empty mask: signal "no ROI found" with x_min == -1.
        x_min, x_max, y_min, y_max, z_min, z_max = -1, image_bbox.shape[0], 0, image_bbox.shape[1], 0, image_bbox.shape[2]
    return x_min, x_max, y_min, y_max, z_min, z_max
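
# Worked example (assumed mask extent; illustrative only): for a mask whose nonzero
# voxels span x in [30, 90], the margin gives [22, 98]; (98 - 22) % 16 == 12, so x_max
# is trimmed to 86 and the final extent 86 - 22 == 64 is divisible by 16.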


def roi_extraction(image, net_roi, ids):
    w, h, d = image.shape

    print('---run the roi binary segmentation.')

    stride_xy = 32
    stride_z = 16
    patch_size_roi_stage = (112, 112, 80)

    # Detect the ROI on a 2x-downsampled volume, then rescale the mask to full size.
    label_roi = roi_detection(net_roi, image[0:w:2, 0:h:2, 0:d:2], stride_xy, stride_z,
                              patch_size_roi_stage)
    print(label_roi.shape, np.max(label_roi))
    label_roi = label_rescale(label_roi, w, h, d, 'trilinear')

    # Clean the mask: drop small components, close gaps with a dilation/erosion pair,
    # and keep only components of at least 400000 voxels.
    label_roi = morphology.remove_small_objects(label_roi.astype(bool), 5000, connectivity=3).astype(float)
    label_roi = ndimage.grey_dilation(label_roi, size=(5, 5, 5))
    label_roi = morphology.remove_small_objects(label_roi.astype(bool), 400000, connectivity=3).astype(float)
    label_roi = ndimage.grey_erosion(label_roi, size=(5, 5, 5))

    x_min, x_max, y_min, y_max, z_min, z_max = img_crop(label_roi)
    if x_min == -1:
        # Empty ROI. Returning a zero volume together with full-volume bounds keeps the
        # return shape consistent with the normal path, so the caller's tuple unpacking
        # still works (an assumption about the intended degenerate behavior).
        return np.zeros((w, h, d)), 0, w, 0, h, 0, d
    image = image[x_min:x_max, y_min:y_max, z_min:z_max]
    print("image shape(after roi): ", image.shape)

    return image, x_min, x_max, y_min, y_max, z_min, z_max


def roi_detection(net, image, stride_xy, stride_z, patch_size):
    w, h, d = image.shape

    # Pad the volume symmetrically if any dimension is smaller than the patch size.
    add_pad = False
    if w < patch_size[0]:
        w_pad = patch_size[0] - w
        add_pad = True
    else:
        w_pad = 0
    if h < patch_size[1]:
        h_pad = patch_size[1] - h
        add_pad = True
    else:
        h_pad = 0
    if d < patch_size[2]:
        d_pad = patch_size[2] - d
        add_pad = True
    else:
        d_pad = 0
    wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2
    hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2
    dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2
    if add_pad:
        image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)], mode='constant',
                       constant_values=0)
    ww, hh, dd = image.shape

    # Sliding-window inference: accumulate softmax scores and a per-voxel visit count.
    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
    score_map = np.zeros((2,) + image.shape).astype(np.float32)
    cnt = np.zeros(image.shape).astype(np.float32)
    for x in range(0, sx):
        xs = min(stride_xy * x, ww - patch_size[0])
        for y in range(0, sy):
            ys = min(stride_xy * y, hh - patch_size[1])
            for z in range(0, sz):
                zs = min(stride_z * z, dd - patch_size[2])
                test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]]
                test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(np.float32)
                test_patch = torch.from_numpy(test_patch).cuda(0)

                with torch.no_grad():
                    logits = net(test_patch)
                    prob = F.softmax(logits, dim=1)
                prob = prob.cpu().data.numpy()[0, :, :, :, :]
                score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + prob
                cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1
    score_map = score_map / np.expand_dims(cnt, axis=0)

    label_map = np.argmax(score_map, axis=0)
    if add_pad:
        label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
        score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
    return label_map
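
# Window-count sketch (assumed volume size; illustrative only): for a padded volume of
# 160 x 160 x 96 with patch_size (112, 112, 80), stride_xy=32 and stride_z=16:
#
#   sx = ceil((160 - 112) / 32) + 1 = 3,  sy = 3,  sz = ceil((96 - 80) / 16) + 1 = 2
#
# so 3 * 3 * 2 = 18 overlapping patches are scored, and `score_map / cnt` averages the
# softmax probabilities wherever patches overlap.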


def test_single_case_array(model_array, image=None, stride_xy=None, stride_z=None, patch_size=None, num_classes=1):
    w, h, d = image.shape

    # Pad the volume symmetrically if any dimension is smaller than the patch size.
    add_pad = False
    if w < patch_size[0]:
        w_pad = patch_size[0] - w
        add_pad = True
    else:
        w_pad = 0
    if h < patch_size[1]:
        h_pad = patch_size[1] - h
        add_pad = True
    else:
        h_pad = 0
    if d < patch_size[2]:
        d_pad = patch_size[2] - d
        add_pad = True
    else:
        d_pad = 0
    wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2
    hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2
    dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2
    if add_pad:
        image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0)

    ww, hh, dd = image.shape

    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
    score_map = np.zeros((num_classes,) + image.shape).astype(np.float32)
    cnt = np.zeros(image.shape).astype(np.float32)

    for x in range(0, sx):
        xs = min(stride_xy * x, ww - patch_size[0])
        for y in range(0, sy):
            ys = min(stride_xy * y, hh - patch_size[1])
            for z in range(0, sz):
                zs = min(stride_z * z, dd - patch_size[2])
                test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]]
                test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(np.float32)
                test_patch = torch.from_numpy(test_patch).cuda()

                # Accumulate and average the ensemble's softmax probabilities for this
                # patch; the accumulator must start from zero for every patch.
                prob = np.zeros((num_classes, patch_size[0], patch_size[1], patch_size[2]), dtype=np.float32)
                with torch.no_grad():
                    for model in model_array:
                        output = model(test_patch)
                        y_temp = F.softmax(output, dim=1)
                        y_temp = y_temp.cpu().data.numpy()
                        prob += y_temp[0, :, :, :, :]
                prob /= len(model_array)
                score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + prob
                cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1
    score_map = score_map / np.expand_dims(cnt, axis=0)

    label_map = np.argmax(score_map, axis=0)
    if add_pad:
        label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
        score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
    return label_map, score_map


def calculate_metric_percase(pred, gt):
    dice = metric.binary.dc(pred, gt)
    jc = metric.binary.jc(pred, gt)
    hd = metric.binary.hd95(pred, gt)
    asd = metric.binary.asd(pred, gt)

    return dice, jc, hd, asd
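
# Metric sketch (illustrative): medpy's binary metrics treat any nonzero voxel as
# foreground. Dice and Jaccard are overlap scores in [0, 1] (higher is better), while
# hd95 and asd are surface distances in voxel units (lower is better), e.g.:
#
#   pred = np.zeros((8, 8, 8)); pred[2:6, 2:6, 2:6] = 1
#   calculate_metric_percase(pred, pred)   # -> (1.0, 1.0, 0.0, 0.0)
#
# Note that hd95/asd raise an error if either volume has no foreground voxels.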


class RailNetSystem(nn.Module, PyTorchModelHubMixin):
    def __init__(self, n_channels: int, n_classes: int, normalization: str):
        super().__init__()

        self.num_classes = n_classes

        # Stage-1 network: coarse binary ROI segmentation.
        self.net_roi = VNet_roi(n_channels=n_channels, n_classes=n_classes, normalization=normalization, has_dropout=False).cuda()

        # Stage-2 ensemble: two V-Nets followed by two ResV-Nets.
        self.model_array = []
        for i in range(4):
            if i < 2:
                model = VNet(n_channels=n_channels, n_classes=n_classes, normalization=normalization, has_dropout=True).cuda()
            else:
                model = ResVNet(n_channels=n_channels, n_classes=n_classes, normalization=normalization, has_dropout=True).cuda()
            self.model_array.append(model)

    def load_weights(self, weight_dir=".", from_hub=False, repo_id=None):
        def load(file_name):
            if from_hub:
                return hf_hub_download(repo_id=repo_id, filename=f"model weights/{file_name}")
            else:
                return os.path.join(weight_dir, "model weights", file_name)

        self.net_roi.load_state_dict(torch.load(load("roi_best_model.pth"), map_location="cuda", weights_only=True))
        self.net_roi.eval()

        model_files = [
            "rail_0_iter_7995_best.pth",
            "rail_1_iter_7995_best.pth",
            "rail_2_iter_7995_best.pth",
            "rail_3_iter_7995_best.pth",
        ]
        for i, file in enumerate(model_files):
            self.model_array[i].load_state_dict(torch.load(load(file), map_location="cuda", weights_only=True))
            self.model_array[i].eval()

    def forward(self, image, label, save_path="./output", name="case"):
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        nib.save(nib.Nifti1Image(image.astype(np.float32), np.eye(4)), os.path.join(save_path, f"{name}_img.nii.gz"))

        w, h, d = image.shape

        # Stage 1: locate the ROI and crop the volume to it.
        image, x_min, x_max, y_min, y_max, z_min, z_max = roi_extraction(image, self.net_roi, name)

        # Stage 2: sliding-window ensemble prediction inside the ROI.
        prediction, _ = test_single_case_array(self.model_array, image, stride_xy=64, stride_z=32, patch_size=(112, 112, 80), num_classes=self.num_classes)

        prediction = morphology.remove_small_objects(prediction.astype(bool), 3000, connectivity=3).astype(float)

        # Paste the ROI prediction back into a full-size volume.
        new_prediction = np.zeros((w, h, d))
        new_prediction[x_min:x_max, y_min:y_max, z_min:z_max] = prediction

        dice, jc, hd, asd = calculate_metric_percase(new_prediction, label[:])

        nib.save(nib.Nifti1Image(new_prediction.astype(np.float32), np.eye(4)), os.path.join(save_path, f"{name}_pred.nii.gz"))

        return new_prediction, dice, jc, hd, asd
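

# Minimal end-to-end sketch (assumptions: a Hub repo "<user>/<repo>" hosting the weights
# under "model weights/", a CUDA device, and a NIfTI volume with a matching label map;
# none of these names come from this file):
#
# if __name__ == "__main__":
#     system = RailNetSystem(n_channels=1, n_classes=2, normalization="instancenorm")
#     system.load_weights(from_hub=True, repo_id="<user>/<repo>")
#     image = nib.load("case_img.nii.gz").get_fdata().astype(np.float32)
#     label = nib.load("case_label.nii.gz").get_fdata().astype(np.uint8)
#     pred, dice, jc, hd, asd = system(image, label, save_path="./output", name="case_001")
#     print(f"dice={dice:.4f}, jaccard={jc:.4f}, hd95={hd:.2f}, asd={asd:.2f}")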