import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import nibabel as nib
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from skimage import morphology
from scipy import ndimage
from medpy import metric


class ConvBlock(nn.Module):
    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
        super(ConvBlock, self).__init__()

        ops = []
        for i in range(n_stages):
            if i == 0:
                input_channel = n_filters_in
            else:
                input_channel = n_filters_out

            ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            elif normalization != 'none':
                assert False
            ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class DownsamplingConvBlock(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(DownsamplingConvBlock, self).__init__()

        ops = []
        if normalization != 'none':
            ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            else:
                assert False
        else:
            ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))

        ops.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class UpsamplingDeconvBlock(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(UpsamplingDeconvBlock, self).__init__()

        ops = []
        if normalization != 'none':
            ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            else:
                assert False
        else:
            ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))

        ops.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class Upsampling(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(Upsampling, self).__init__()

        ops = []
        ops.append(nn.Upsample(scale_factor=stride, mode='trilinear', align_corners=False))
        ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1))
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            assert False
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x
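# Minimal sanity-check sketch (not part of the original pipeline; assumes a
# small CPU tensor): DownsamplingConvBlock halves the spatial dims and
# UpsamplingDeconvBlock restores them, which is what the V-Net skip
# connections below rely on. `_demo_block_shapes` is a hypothetical helper.
def _demo_block_shapes():
    x = torch.zeros(1, 16, 32, 32, 16)  # (B, C, W, H, D)
    down = DownsamplingConvBlock(16, 32, normalization='instancenorm')
    up = UpsamplingDeconvBlock(32, 16, normalization='instancenorm')
    y = down(x)  # kernel=stride=2 conv -> (1, 32, 16, 16, 8)
    z = up(y)    # kernel=stride=2 deconv -> (1, 16, 32, 32, 16)
    assert y.shape == (1, 32, 16, 16, 8) and z.shape == x.shape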
class ConnectNet(nn.Module):
    def __init__(self, in_channels, out_channels, input_size):
        super(ConnectNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv3d(in_channels, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2),
            nn.Conv3d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(64, 128, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose3d(128, out_channels, kernel_size=2, stride=2),
            nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded


class VNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
        super(VNet, self).__init__()
        self.has_dropout = has_dropout

        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)

        self.dropout = nn.Dropout3d(p=0.5, inplace=False)
        self.__init_weight()

    def encoder(self, input):
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)
        if self.has_dropout:
            x5 = self.dropout(x5)

        res = [x1, x2, x3, x4, x5]
        return res

    def decoder(self, features):
        x1 = features[0]
        x2 = features[1]
        x3 = features[2]
        x4 = features[3]
        x5 = features[4]

        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1

        x9 = self.block_nine(x8_up)
        if self.has_dropout:
            x9 = self.dropout(x9)
        out = self.out_conv(x9)
        return out

    def forward(self, input, turnoff_drop=False):
        if turnoff_drop:
            has_dropout = self.has_dropout
            self.has_dropout = False
        features = self.encoder(input)
        out = self.decoder(features)
        if turnoff_drop:
            self.has_dropout = has_dropout
        return out

    def __init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
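# Sketch (assumptions: CPU tensor, instancenorm): VNet is fully convolutional
# with four 2x downsamplings, so any input whose spatial dims are multiples of
# 16 yields per-voxel logits of the same spatial size. `_demo_vnet_shapes` is
# a hypothetical helper name.
def _demo_vnet_shapes():
    net = VNet(n_channels=1, n_classes=2, normalization='instancenorm', has_dropout=False)
    x = torch.zeros(1, 1, 112, 112, 80)  # the patch size used by the pipeline below
    with torch.no_grad():
        logits = net(x)
    assert logits.shape == (1, 2, 112, 112, 80)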
class VNet_roi(nn.Module):
    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
        super(VNet_roi, self).__init__()
        self.has_dropout = has_dropout

        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)

        self.dropout = nn.Dropout3d(p=0.5, inplace=False)
        # self.__init_weight()

    def encoder(self, input):
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)
        # x5 = F.dropout3d(x5, p=0.5, training=True)
        if self.has_dropout:
            x5 = self.dropout(x5)

        res = [x1, x2, x3, x4, x5]
        return res

    def decoder(self, features):
        x1 = features[0]
        x2 = features[1]
        x3 = features[2]
        x4 = features[3]
        x5 = features[4]

        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1

        x9 = self.block_nine(x8_up)
        # x9 = F.dropout3d(x9, p=0.5, training=True)
        if self.has_dropout:
            x9 = self.dropout(x9)
        out = self.out_conv(x9)
        return out

    def forward(self, input, turnoff_drop=False):
        if turnoff_drop:
            has_dropout = self.has_dropout
            self.has_dropout = False
        features = self.encoder(input)
        out = self.decoder(features)
        if turnoff_drop:
            self.has_dropout = has_dropout
        return out
class ResVNet(nn.Module):
    def __init__(self, n_channels=1, n_classes=2, n_filters=16, normalization='instancenorm', has_dropout=False):
        super(ResVNet, self).__init__()
        self.resencoder = resnet34()
        self.has_dropout = has_dropout

        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)

        if has_dropout:
            self.dropout = nn.Dropout3d(p=0.5)

        self.branchs = nn.ModuleList()
        for i in range(1):
            if has_dropout:
                seq = nn.Sequential(
                    ConvBlock(1, n_filters, n_filters, normalization=normalization),
                    nn.Dropout3d(p=0.5),
                    nn.Conv3d(n_filters, n_classes, 1, padding=0)
                )
            else:
                seq = nn.Sequential(
                    ConvBlock(1, n_filters, n_filters, normalization=normalization),
                    nn.Conv3d(n_filters, n_classes, 1, padding=0)
                )
            self.branchs.append(seq)

    def encoder(self, input):
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)
        if self.has_dropout:
            x5 = self.dropout(x5)

        res = [x1, x2, x3, x4, x5]
        return res

    def decoder(self, features):
        x1 = features[0]
        x2 = features[1]
        x3 = features[2]
        x4 = features[3]
        x5 = features[4]

        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1

        x9 = self.block_nine(x8_up)
        out = self.out_conv(x9)
        return out

    def forward(self, input, turnoff_drop=False):
        if turnoff_drop:
            has_dropout = self.has_dropout
            self.has_dropout = False
        features = self.resencoder(input)
        out = self.decoder(features)
        if turnoff_drop:
            self.has_dropout = has_dropout
        return out


__all__ = ['ResNet', 'resnet34']


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    return nn.Sequential(
        conv3x3(in_planes, out_planes, stride),
        nn.InstanceNorm3d(out_planes),
        nn.ReLU()
    )
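# Sketch (assumption: CPU, default width=1): the ResNet-34 encoder defined
# below emits a 5-level feature pyramid whose channel widths and 2x-per-level
# spatial scales line up with the VNet decoder that ResVNet reuses.
# `_demo_resvnet_pyramid` is a hypothetical helper name.
def _demo_resvnet_pyramid():
    enc = resnet34()  # defined later in this module
    x = torch.zeros(1, 1, 64, 64, 64)
    with torch.no_grad():
        feats = enc(x)
    assert [f.shape[1] for f in feats] == [16, 32, 64, 128, 256]  # channels
    assert [f.shape[2] for f in feats] == [64, 32, 16, 8, 4]      # halved per level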
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=-1):
        super(BasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.InstanceNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.InstanceNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1):
        super(Bottleneck, self).__init__()
        width = int(planes * (base_width / 64.)) * groups
        self.conv1 = nn.Conv3d(inplanes, width, kernel_size=1, bias=False)
        self.bn1 = nn.InstanceNorm3d(width)
        self.conv2 = nn.Conv3d(width, width, kernel_size=3, stride=stride, dilation=dilation,
                               padding=dilation, groups=groups, bias=False)
        self.bn2 = nn.InstanceNorm3d(width)
        self.conv3 = nn.Conv3d(width, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.InstanceNorm3d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
class ResNet(nn.Module):
    def __init__(self, block, layers, in_channel=1, width=1,
                 groups=1, width_per_group=64,
                 mid_dim=1024, low_dim=128,
                 avg_down=False, deep_stem=False,
                 head_type='mlp_head', layer4_dilation=1):
        super(ResNet, self).__init__()
        self.avg_down = avg_down
        self.inplanes = 16 * width
        self.base = int(16 * width)
        self.groups = groups
        self.base_width = width_per_group

        mid_dim = self.base * 8 * block.expansion

        if deep_stem:
            self.conv1 = nn.Sequential(
                conv3x3_bn_relu(in_channel, 32, stride=2),
                conv3x3_bn_relu(32, 32, stride=1),
                conv3x3(32, 64, stride=1)
            )
        else:
            self.conv1 = nn.Conv3d(in_channel, self.inplanes, kernel_size=7, stride=1, padding=3, bias=False)

        self.bn1 = nn.InstanceNorm3d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, self.base * 2, layers[0], stride=2)
        self.layer2 = self._make_layer(block, self.base * 4, layers[1], stride=2)
        self.layer3 = self._make_layer(block, self.base * 8, layers[2], stride=2)
        if layer4_dilation == 1:
            self.layer4 = self._make_layer(block, self.base * 16, layers[3], stride=2)
        elif layer4_dilation == 2:
            self.layer4 = self._make_layer(block, self.base * 16, layers[3], stride=1, dilation=2)
        else:
            raise NotImplementedError
        self.avgpool = nn.AvgPool3d(7, stride=1)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if self.avg_down:
                downsample = nn.Sequential(
                    nn.AvgPool3d(kernel_size=stride, stride=stride),
                    nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=1, bias=False),
                    nn.InstanceNorm3d(planes * block.expansion),
                )
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                    nn.InstanceNorm3d(planes * block.expansion),
                )

        layers = [block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, dilation)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # c2 = self.maxpool(x)
        c2 = self.layer1(x)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)
        return [x, c2, c3, c4, c5]


def resnet34(**kwargs):
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


def label_rescale(image_label, w_ori, h_ori, z_ori, flag):
    w_ori, h_ori, z_ori = int(w_ori), int(h_ori), int(z_ori)
    # resize label map (int)
    if flag == 'trilinear':
        teeth_ids = np.unique(image_label)
        image_label_ori = np.zeros((w_ori, h_ori, z_ori))
        image_label = torch.from_numpy(image_label).cuda(0)
        for label_id in range(len(teeth_ids)):
            image_label_bn = (image_label == teeth_ids[label_id]).float()
            image_label_bn = image_label_bn[None, None, :, :, :]
            image_label_bn = torch.nn.functional.interpolate(image_label_bn, size=(w_ori, h_ori, z_ori),
                                                             mode='trilinear', align_corners=False)
            image_label_bn = image_label_bn[0, 0, :, :, :]
            image_label_bn = image_label_bn.cpu().data.numpy()
            image_label_ori[image_label_bn > 0.5] = teeth_ids[label_id]
        image_label = image_label_ori

    if flag == 'nearest':
        image_label = torch.from_numpy(image_label).cuda(0)
        image_label = image_label[None, None, :, :, :].float()
        image_label = torch.nn.functional.interpolate(image_label, size=(w_ori, h_ori, z_ori), mode='nearest')
        image_label = image_label[0, 0, :, :, :].cpu().data.numpy()
    return image_label


def img_crop(image_bbox):
    if image_bbox.sum() > 0:
        x_min = np.nonzero(image_bbox)[0].min() - 8
        x_max = np.nonzero(image_bbox)[0].max() + 8
        y_min = np.nonzero(image_bbox)[1].min() - 16
        y_max = np.nonzero(image_bbox)[1].max() + 16
        z_min = np.nonzero(image_bbox)[2].min() - 16
        z_max = np.nonzero(image_bbox)[2].max() + 16

        if x_min < 0:
            x_min = 0
        if y_min < 0:
            y_min = 0
        if z_min < 0:
            z_min = 0
        if x_max > image_bbox.shape[0]:
            x_max = image_bbox.shape[0]
        if y_max > image_bbox.shape[1]:
            y_max = image_bbox.shape[1]
        if z_max > image_bbox.shape[2]:
            z_max = image_bbox.shape[2]

        # trim so each side length is a multiple of 16
        if (x_max - x_min) % 16 != 0:
            x_max -= (x_max - x_min) % 16
        if (y_max - y_min) % 16 != 0:
            y_max -= (y_max - y_min) % 16
        if (z_max - z_min) % 16 != 0:
            z_max -= (z_max - z_min) % 16

    if image_bbox.sum() == 0:
        x_min, x_max, y_min, y_max, z_min, z_max = \
            -1, image_bbox.shape[0], 0, image_bbox.shape[1], 0, image_bbox.shape[2]

    return x_min, x_max, y_min, y_max, z_min, z_max


def roi_extraction(image, net_roi, ids):
    w, h, d = image.shape
    # roi binary segmentation parameters, the input spacing is 0.4 mm
    print('---run the roi binary segmentation.')
    stride_xy = 32
    stride_z = 16
    patch_size_roi_stage = (112, 112, 80)
    label_roi = roi_detection(net_roi, image[0:w:2, 0:h:2, 0:d:2], stride_xy, stride_z,
                              patch_size_roi_stage)  # e.g. (400, 400, 200)
    print(label_roi.shape, np.max(label_roi))
    label_roi = label_rescale(label_roi, w, h, d, 'trilinear')  # e.g. (800, 800, 400)
    label_roi = morphology.remove_small_objects(label_roi.astype(bool), 5000, connectivity=3).astype(float)
    label_roi = ndimage.grey_dilation(label_roi, size=(5, 5, 5))
    label_roi = morphology.remove_small_objects(label_roi.astype(bool), 400000, connectivity=3).astype(float)
    label_roi = ndimage.grey_erosion(label_roi, size=(5, 5, 5))

    # crop image
    x_min, x_max, y_min, y_max, z_min, z_max = img_crop(label_roi)
    if x_min == -1:
        # no foreground detected: return a bare zero volume; callers must check
        # for this non-tuple return before unpacking crop bounds
        whole_label = np.zeros((w, h, d))
        return whole_label
    image = image[x_min:x_max, y_min:y_max, z_min:z_max]
    print("image shape(after roi): ", image.shape)
    return image, x_min, x_max, y_min, y_max, z_min, z_max
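# Toy sketch (assumption: a small synthetic mask) of how img_crop behaves: it
# pads the foreground bounding box (8 voxels in x, 16 in y/z), clamps it to
# the volume, then trims each side to a multiple of 16 so the V-Net's four 2x
# downsamplings divide evenly. `_demo_img_crop` is a hypothetical helper.
def _demo_img_crop():
    mask = np.zeros((128, 128, 128))
    mask[40:80, 40:80, 40:80] = 1.0
    x_min, x_max, y_min, y_max, z_min, z_max = img_crop(mask)
    assert (x_max - x_min) % 16 == 0
    assert (y_max - y_min) % 16 == 0 and (z_max - z_min) % 16 == 0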
def roi_detection(net, image, stride_xy, stride_z, patch_size):
    w, h, d = image.shape  # e.g. (400, 400, 200)

    # if the size of image is less than patch_size, then padding it
    add_pad = False
    if w < patch_size[0]:
        w_pad = patch_size[0] - w
        add_pad = True
    else:
        w_pad = 0
    if h < patch_size[1]:
        h_pad = patch_size[1] - h
        add_pad = True
    else:
        h_pad = 0
    if d < patch_size[2]:
        d_pad = patch_size[2] - d
        add_pad = True
    else:
        d_pad = 0
    wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2
    hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2
    dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2
    if add_pad:
        image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)],
                       mode='constant', constant_values=0)
    ww, hh, dd = image.shape

    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
    score_map = np.zeros((2,) + image.shape).astype(np.float32)
    cnt = np.zeros(image.shape).astype(np.float32)
    count = 0
    for x in range(0, sx):
        xs = min(stride_xy * x, ww - patch_size[0])
        for y in range(0, sy):
            ys = min(stride_xy * y, hh - patch_size[1])
            for z in range(0, sz):
                zs = min(stride_z * z, dd - patch_size[2])
                test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]]
                test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(np.float32)
                test_patch = torch.from_numpy(test_patch).cuda(0)
                with torch.no_grad():
                    y1 = net(test_patch)      # (1, 2, *patch_size)
                    y = F.softmax(y1, dim=1)  # (1, 2, *patch_size)
                y = y.cpu().data.numpy()
                y = y[0, :, :, :, :]          # (2, *patch_size)
                score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + y
                cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1
                count = count + 1
    score_map = score_map / np.expand_dims(cnt, axis=0)
    label_map = np.argmax(score_map, axis=0)  # 0/1 per voxel
    if add_pad:
        label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
        score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
    return label_map
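# Sketch of the sliding-window arithmetic used above (pure Python, no model,
# hypothetical helper name): the number of window positions along an axis of
# length `ww` with patch `ps` and stride `st` is ceil((ww - ps) / st) + 1, and
# each start is clamped so the last patch ends flush with the border.
def _demo_patch_grid(ww=400, ps=112, st=32):
    sx = math.ceil((ww - ps) / st) + 1
    starts = [min(st * x, ww - ps) for x in range(sx)]
    assert starts[-1] + ps == ww  # last window touches the border exactly
    return starts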
def test_single_case_array(model_array, image=None, stride_xy=None, stride_z=None, patch_size=None, num_classes=1):
    w, h, d = image.shape

    # if the size of image is less than patch_size, then padding it
    add_pad = False
    if w < patch_size[0]:
        w_pad = patch_size[0] - w
        add_pad = True
    else:
        w_pad = 0
    if h < patch_size[1]:
        h_pad = patch_size[1] - h
        add_pad = True
    else:
        h_pad = 0
    if d < patch_size[2]:
        d_pad = patch_size[2] - d
        add_pad = True
    else:
        d_pad = 0
    wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2
    hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2
    dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2
    if add_pad:
        image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)],
                       mode='constant', constant_values=0)
    ww, hh, dd = image.shape

    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
    score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32)
    cnt = np.zeros(image.shape).astype(np.float32)

    for x in range(0, sx):
        xs = min(stride_xy * x, ww - patch_size[0])
        for y in range(0, sy):
            ys = min(stride_xy * y, hh - patch_size[1])
            for z in range(0, sz):
                zs = min(stride_z * z, dd - patch_size[2])
                test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]]
                test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(np.float32)
                test_patch = torch.from_numpy(test_patch).cuda()
                # average the softmax probabilities over the model ensemble; the
                # accumulator must be initialized fresh for every patch
                y_ens = np.zeros((num_classes, patch_size[0], patch_size[1], patch_size[2]), dtype=np.float32)
                with torch.no_grad():  # inference only; no autograd graphs
                    for model in model_array:
                        output = model(test_patch)
                        y_temp = F.softmax(output, dim=1)
                        y_temp = y_temp.cpu().data.numpy()
                        y_ens += y_temp[0, :, :, :, :]
                y_ens /= len(model_array)
                score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + y_ens
                cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1
    score_map = score_map / np.expand_dims(cnt, axis=0)
    label_map = np.argmax(score_map, axis=0)
    if add_pad:
        label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
        score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
    return label_map, score_map


def calculate_metric_percase(pred, gt):
    dice = metric.binary.dc(pred, gt)
    jc = metric.binary.jc(pred, gt)
    hd = metric.binary.hd95(pred, gt)
    asd = metric.binary.asd(pred, gt)
    return dice, jc, hd, asd


class RailNetSystem(nn.Module, PyTorchModelHubMixin):
    def __init__(self, n_channels: int, n_classes: int, normalization: str):
        super().__init__()
        self.num_classes = 2

        self.net_roi = VNet_roi(n_channels=n_channels, n_classes=n_classes,
                                normalization=normalization, has_dropout=False).cuda()

        self.model_array = []
        for i in range(4):
            if i < 2:
                model = VNet(n_channels=n_channels, n_classes=n_classes,
                             normalization=normalization, has_dropout=True).cuda()
            else:
                model = ResVNet(n_channels=n_channels, n_classes=n_classes,
                                normalization=normalization, has_dropout=True).cuda()
            self.model_array.append(model)

    def load_weights(self, weight_dir=".", from_hub=False, repo_id=None):
        def load(file_name):
            if from_hub:
                return hf_hub_download(repo_id=repo_id, filename=f"model weights/{file_name}")
            else:
                return os.path.join(weight_dir, "model weights", file_name)

        self.net_roi.load_state_dict(torch.load(load("roi_best_model.pth"),
                                                map_location="cuda", weights_only=True))
        self.net_roi.eval()

        model_files = [
            "rail_0_iter_7995_best.pth",
            "rail_1_iter_7995_best.pth",
            "rail_2_iter_7995_best.pth",
            "rail_3_iter_7995_best.pth",
        ]
        for i, file in enumerate(model_files):
            self.model_array[i].load_state_dict(torch.load(load(file),
                                                           map_location="cuda", weights_only=True))
            self.model_array[i].eval()

    def forward(self, image, label, save_path="./output", name="case"):
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        nib.save(nib.Nifti1Image(image.astype(np.float32), np.eye(4)),
                 os.path.join(save_path, f"{name}_img.nii.gz"))

        w, h, d = image.shape
        roi = roi_extraction(image, self.net_roi, name)
        if not isinstance(roi, tuple):
            # roi_extraction returns a bare zero volume when no foreground is
            # detected; bail out before unpacking crop bounds. The overlap
            # metrics are undefined for an empty prediction, so zeros are
            # reported as a fallback here.
            new_prediction = roi
            nib.save(nib.Nifti1Image(new_prediction.astype(np.float32), np.eye(4)),
                     os.path.join(save_path, f"{name}_pred.nii.gz"))
            return new_prediction, 0.0, 0.0, 0.0, 0.0
        image, x_min, x_max, y_min, y_max, z_min, z_max = roi

        prediction, _ = test_single_case_array(self.model_array, image, stride_xy=64, stride_z=32,
                                               patch_size=(112, 112, 80), num_classes=self.num_classes)

        prediction = morphology.remove_small_objects(prediction.astype(bool), 3000, connectivity=3).astype(float)

        new_prediction = np.zeros((w, h, d))
        new_prediction[x_min:x_max, y_min:y_max, z_min:z_max] = prediction

        dice, jc, hd, asd = calculate_metric_percase(new_prediction, label[:])
        nib.save(nib.Nifti1Image(new_prediction.astype(np.float32), np.eye(4)),
                 os.path.join(save_path, f"{name}_pred.nii.gz"))

        return new_prediction, dice, jc, hd, asd
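
# Usage sketch (assumptions: a CUDA device is available, and the repo id and
# NIfTI file paths below are placeholders, not confirmed by this module; the
# 'instancenorm' choice follows the ResVNet default):
if __name__ == "__main__":
    net = RailNetSystem(n_channels=1, n_classes=2, normalization="instancenorm")
    net.load_weights(from_hub=True, repo_id="<user>/RailNet")

    image = nib.load("case_img.nii.gz").get_fdata()
    label = nib.load("case_label.nii.gz").get_fdata()

    pred, dice, jc, hd, asd = net(image, label, save_path="./output", name="case")
    print(f"dice={dice:.4f} jc={jc:.4f} hd95={hd:.2f} asd={asd:.2f}")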