# https://github.com/wolny/pytorch-3dunet/blob/master/pytorch3dunet/unet3d/buildingblocks.py
# MIT License
#
# Copyright (c) 2018 Adrian Wolny
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from functools import partial

import torch
from torch import nn as nn
from torch.nn import functional as F

# NOTE: the squeeze-and-excitation variants from upstream
# (pytorch3dunet.unet3d.se: ChannelSELayer3D, ChannelSpatialSELayer3D, SpatialSELayer3D, ResNetBlockSE)
# are not included in this copy.


def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding, dropout_prob, is3d):
    """
    Create a list of modules which together constitute a single conv layer with non-linearity
    and optional batchnorm/groupnorm.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size (int or tuple): size of the convolving kernel
        order (string): order of the layers, one character per layer, e.g.
            'cr' -> conv + ReLU
            'gcr' -> groupnorm + conv + ReLU
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
            'bcr' -> batchnorm + conv + ReLU
            'cbrd' -> conv + batchnorm + ReLU + dropout
            'cbrD' -> conv + batchnorm + ReLU + dropout2d
        num_groups (int): number of groups for the GroupNorm; falls back to 1 when it is greater than
            the number of normalized channels
        padding (int or tuple): zero-padding added to all sides of the input
        dropout_prob (float): dropout probability
        is3d (bool): if True use Conv3d/BatchNorm3d, otherwise use Conv2d/BatchNorm2d

    Return:
        list of tuple (name, module)

    Raises:
        AssertionError: if `order` has no conv or starts with a non-linearity, or if the GroupNorm
            channels are not divisible by `num_groups`
        ValueError: if `order` contains an unsupported character
    """
    assert 'c' in order, "Conv layer MUST be present"
    assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'

    modules = []
    for i, char in enumerate(order):
        if char == 'r':
            modules.append(('ReLU', nn.ReLU(inplace=True)))
        elif char == 'l':
            modules.append(('LeakyReLU', nn.LeakyReLU(inplace=True)))
        elif char == 'e':
            modules.append(('ELU', nn.ELU(inplace=True)))
        elif char == 'c':
            # add learnable bias only in the absence of batchnorm/groupnorm
            bias = not ('g' in order or 'b' in order)
            if is3d:
                conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
            else:
                conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
            modules.append(('conv', conv))
        elif char == 'g':
            # a norm placed before the conv sees the input channels, otherwise the output channels
            is_before_conv = i < order.index('c')
            if is_before_conv:
                num_channels = in_channels
            else:
                num_channels = out_channels

            # use only one group if the given number of groups is greater than the number of channels
            if num_channels < num_groups:
                num_groups = 1

            assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}'
            modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
        elif char == 'b':
            is_before_conv = i < order.index('c')
            if is3d:
                bn = nn.BatchNorm3d
            else:
                bn = nn.BatchNorm2d

            if is_before_conv:
                modules.append(('batchnorm', bn(in_channels)))
            else:
                modules.append(('batchnorm', bn(out_channels)))
        elif char == 'd':
            modules.append(('dropout', nn.Dropout(p=dropout_prob)))
        elif char == 'D':
            # NOTE(review): Dropout2d is used even when is3d=True (same as upstream)
            modules.append(('dropout2d', nn.Dropout2d(p=dropout_prob)))
        else:
            raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c', 'd', 'D']")

    return modules


class SingleConv(nn.Sequential):
    """
    Basic convolutional module consisting of a Conv3d (or Conv2d), non-linearity and optional
    batchnorm/groupnorm. The order of operations can be specified via the `order` parameter.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size (int or tuple): size of the convolving kernel
        order (string): determines the order of layers, e.g.
            'cr' -> conv + ReLU
            'crg' -> conv + ReLU + groupnorm
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
            See `create_conv` for the full list of supported characters.
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): zero-padding added to all sides of the input
        dropout_prob (float): dropout probability, default 0.1
        is3d (bool): if True use Conv3d, otherwise use Conv2d
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, order='gcr', num_groups=8,
                 padding=1, dropout_prob=0.1, is3d=True):
        super(SingleConv, self).__init__()

        for name, module in create_conv(in_channels, out_channels, kernel_size, order,
                                        num_groups, padding, dropout_prob, is3d):
            self.add_module(name, module)


class DoubleConv(nn.Sequential):
    """
    A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
    The default order is 'gcr' (GroupNorm+Conv+ReLU). This can be changed by providing the
    'order' argument, e.g. in order to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
    Use padded convolutions to make sure that the output (H_out, W_out) is the same
    as (H_in, W_in), so that you don't have to crop in the decoder path.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        encoder (bool): if True we're in the encoder path, otherwise we're in the decoder
        kernel_size (int or tuple): size of the convolving kernel
        order (string): determines the order of layers, e.g.
            'cr' -> conv + ReLU
            'crg' -> conv + ReLU + groupnorm
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): zero-padding added to all sides of the input
        upscale (int): encoder only. If 1, the first convolution already outputs `out_channels`;
            otherwise it outputs `out_channels // 2` (but never fewer than `in_channels`). Default: 2
        dropout_prob (float or tuple): dropout probability for each convolution; a list/tuple gives
            the probabilities of the first and second convolution separately. Default 0.1
        is3d (bool): if True use Conv3d instead of Conv2d layers
    """

    def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr',
                 num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True):
        super(DoubleConv, self).__init__()
        if encoder:
            # we're in the encoder path
            conv1_in_channels = in_channels
            if upscale == 1:
                conv1_out_channels = out_channels
            else:
                conv1_out_channels = out_channels // 2
            if conv1_out_channels < in_channels:
                conv1_out_channels = in_channels
            conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels
        else:
            # we're in the decoder path, decrease the number of channels in the 1st convolution
            conv1_in_channels, conv1_out_channels = in_channels, out_channels
            conv2_in_channels, conv2_out_channels = out_channels, out_channels

        # a list/tuple gives a separate dropout probability for each convolution
        if isinstance(dropout_prob, (list, tuple)):
            dropout_prob1 = dropout_prob[0]
            dropout_prob2 = dropout_prob[1]
        else:
            dropout_prob1 = dropout_prob2 = dropout_prob

        # conv1
        self.add_module('SingleConv1',
                        SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups,
                                   padding=padding, dropout_prob=dropout_prob1, is3d=is3d))
        # conv2
        self.add_module('SingleConv2',
                        SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups,
                                   padding=padding, dropout_prob=dropout_prob2, is3d=is3d))


class ResNetBlock(nn.Module):
    """
    Residual block that can be used instead of standard DoubleConv in the Encoder module.
    Motivated by: https://arxiv.org/pdf/1706.00120.pdf

    Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm.

    The residual branch is a 1x1 conv when the channel count changes, otherwise identity.
    Both inner SingleConv layers use their default padding (1). Extra keyword arguments passed by
    Encoder/Decoder (e.g. `encoder`, `padding`, `upscale`, `dropout_prob`) are accepted via
    `**kwargs` and ignored.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, is3d=True, **kwargs):
        super(ResNetBlock, self).__init__()

        if in_channels != out_channels:
            # conv1x1 for increasing the number of channels
            if is3d:
                self.conv1 = nn.Conv3d(in_channels, out_channels, 1)
            else:
                self.conv1 = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.conv1 = nn.Identity()

        self.conv2 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order,
                                num_groups=num_groups, is3d=is3d)
        # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual
        n_order = order
        for c in 'rel':
            n_order = n_order.replace(c, '')
        self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order,
                                num_groups=num_groups, is3d=is3d)

        # create non-linearity separately
        if 'l' in order:
            self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        elif 'e' in order:
            self.non_linearity = nn.ELU(inplace=True)
        else:
            self.non_linearity = nn.ReLU(inplace=True)

    def forward(self, x):
        # apply first convolution to bring the number of channels to out_channels
        residual = self.conv1(x)

        out = self.conv2(x)
        out = self.conv3(out)

        out += residual
        out = self.non_linearity(out)

        return out


class Encoder(nn.Module):
    """
    A single module from the encoder path, consisting of an optional pooling layer followed by a
    basic module (DoubleConv or ResNetBlock). The pooling kernel_size may differ from the standard
    (2,2,2), e.g. if the volumetric data is anisotropic. In that case use a complementary
    scale_factor in the decoder path.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        conv_kernel_size (int or tuple): size of the convolving kernel
        apply_pooling (bool): if True apply the pooling layer before the basic module
        pool_kernel_size (int or tuple): the size of the window
        pool_type (str): pooling layer: 'max' or 'avg'
        basic_module(nn.Module): either ResNetBlock or DoubleConv
        conv_layer_order (string): determines the order of layers
            in `DoubleConv` module. See `DoubleConv` for more info.
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): zero-padding added to all sides of the input
        upscale (int): see `DoubleConv`; default: 2
        dropout_prob (float or tuple): dropout probability, default 0.1
        is3d (bool): use 3d or 2d convolutions/pooling operation
    """

    def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
                 pool_kernel_size=2, pool_type='max', basic_module=DoubleConv, conv_layer_order='gcr',
                 num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True):
        super(Encoder, self).__init__()
        assert pool_type in ['max', 'avg']
        if apply_pooling:
            if pool_type == 'max':
                if is3d:
                    self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
                else:
                    self.pooling = nn.MaxPool2d(kernel_size=pool_kernel_size)
            else:
                if is3d:
                    self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
                else:
                    self.pooling = nn.AvgPool2d(kernel_size=pool_kernel_size)
        else:
            self.pooling = None

        self.basic_module = basic_module(in_channels, out_channels,
                                         encoder=True,
                                         kernel_size=conv_kernel_size,
                                         order=conv_layer_order,
                                         num_groups=num_groups,
                                         padding=padding,
                                         upscale=upscale,
                                         dropout_prob=dropout_prob,
                                         is3d=is3d)

    def forward(self, x):
        if self.pooling is not None:
            x = self.pooling(x)
        x = self.basic_module(x)
        return x


class Decoder(nn.Module):
    """
    A single module for the decoder path, consisting of an upsampling layer (either learned
    ConvTranspose3d or interpolation), a joining step with the encoder features (concatenation or
    summation), and a basic module (DoubleConv or ResNetBlock).

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        conv_kernel_size (int or tuple): size of the convolving kernel
        scale_factor (int or tuple): stride of the ConvTranspose3d when upsample == 'deconv'; it must
            reverse the pooling operation from the corresponding encoder
        basic_module(nn.Module): either ResNetBlock or DoubleConv
        conv_layer_order (string): determines the order of layers
            in `DoubleConv` module. See `DoubleConv` for more info.
        num_groups (int): number of groups for the GroupNorm
        padding (int or tuple): zero-padding added to all sides of the input
        upsample (str): algorithm used for upsampling:
            InterpolateUpsampling:   'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'
            TransposeConvUpsampling: 'deconv'
            No upsampling:           None or 'none'
            Default: 'default' ('nearest' + concat for DoubleConv, 'deconv' + summation for ResNetBlock)
        dropout_prob (float or tuple): dropout probability, default 0.1
        is3d (bool): use 3d or 2d convolutions/upsampling
    """

    def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=2, basic_module=DoubleConv,
                 conv_layer_order='gcr', num_groups=8, padding=1, upsample='default',
                 dropout_prob=0.1, is3d=True):
        super(Decoder, self).__init__()

        # perform concat joining per default
        concat = True
        # don't adapt channels after join operation
        adapt_channels = False

        if upsample is not None and upsample != 'none':
            if upsample == 'default':
                if basic_module is DoubleConv:
                    upsample = 'nearest'  # use nearest neighbor interpolation for upsampling
                    concat = True  # use concat joining
                    adapt_channels = False  # don't adapt channels
                elif basic_module is ResNetBlock:
                    upsample = 'deconv'  # use deconvolution upsampling
                    concat = False  # use summation joining
                    adapt_channels = True  # adapt channels after joining
                # NOTE(review): any other basic_module leaves upsample == 'default', which is then passed
                # to F.interpolate as the mode below — confirm callers never hit this path.

            # perform deconvolution upsampling if mode is deconv
            if upsample == 'deconv':
                self.upsampling = TransposeConvUpsampling(in_channels=in_channels, out_channels=out_channels,
                                                          kernel_size=conv_kernel_size, scale_factor=scale_factor,
                                                          is3d=is3d)
            else:
                self.upsampling = InterpolateUpsampling(mode=upsample)
        else:
            # no upsampling; the features are joined by concatenation (concat stays True)
            self.upsampling = NoUpsampling()

        # perform joining operation: channel concatenation or element-wise summation
        self.joining = partial(self._joining, concat=concat)

        # after summation joining the number of channels equals out_channels
        if adapt_channels:
            in_channels = out_channels

        self.basic_module = basic_module(in_channels, out_channels,
                                         encoder=False,
                                         kernel_size=conv_kernel_size,
                                         order=conv_layer_order,
                                         num_groups=num_groups,
                                         padding=padding,
                                         dropout_prob=dropout_prob,
                                         is3d=is3d)

    def forward(self, encoder_features, x):
        x = self.upsampling(encoder_features=encoder_features, x=x)
        x = self.joining(encoder_features, x)
        x = self.basic_module(x)
        return x

    @staticmethod
    def _joining(encoder_features, x, concat):
        # concatenate along the channel dimension, or sum element-wise
        if concat:
            return torch.cat((encoder_features, x), dim=1)
        else:
            return encoder_features + x


def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding,
                    conv_upscale, dropout_prob, layer_order, num_groups, pool_kernel_size, is3d):
    """Create the encoder path as an nn.ModuleList of `len(f_maps)` Encoder modules.

    The first encoder maps `in_channels` to `f_maps[0]` without pooling; every following encoder
    pools and then maps `f_maps[i - 1]` to `f_maps[i]`.
    """
    encoders = []
    for i, out_feature_num in enumerate(f_maps):
        if i == 0:
            encoder = Encoder(in_channels, out_feature_num,
                              apply_pooling=False,  # skip pooling in the first encoder
                              basic_module=basic_module,
                              conv_layer_order=layer_order,
                              conv_kernel_size=conv_kernel_size,
                              num_groups=num_groups,
                              padding=conv_padding,
                              upscale=conv_upscale,
                              dropout_prob=dropout_prob,
                              is3d=is3d)
        else:
            encoder = Encoder(f_maps[i - 1], out_feature_num,
                              basic_module=basic_module,
                              conv_layer_order=layer_order,
                              conv_kernel_size=conv_kernel_size,
                              num_groups=num_groups,
                              pool_kernel_size=pool_kernel_size,
                              padding=conv_padding,
                              upscale=conv_upscale,
                              dropout_prob=dropout_prob,
                              is3d=is3d)

        encoders.append(encoder)

    return nn.ModuleList(encoders)


def create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order,
                    num_groups, upsample, dropout_prob, is3d):
    """Create the decoder path as an nn.ModuleList of Decoder modules.

    The decoders are built from `f_maps[1:]` in reverse order, so the list has
    `max(len(f_maps) - 2, 0)` entries.
    NOTE(review): upstream pytorch-3dunet uses `reversed(f_maps)` and builds `len(f_maps) - 1`
    decoders; this copy skips the first (f_maps[0]) level — confirm callers expect that.
    """
    decoders = []
    reversed_f_maps = list(reversed(f_maps[1:]))
    for i in range(len(reversed_f_maps) - 1):
        # interpolation + concat joining stacks the encoder and upsampled channels
        if basic_module == DoubleConv and upsample != 'deconv':
            in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
        else:
            in_feature_num = reversed_f_maps[i]

        out_feature_num = reversed_f_maps[i + 1]

        decoder = Decoder(in_feature_num, out_feature_num,
                          basic_module=basic_module,
                          conv_layer_order=layer_order,
                          conv_kernel_size=conv_kernel_size,
                          num_groups=num_groups,
                          padding=conv_padding,
                          upsample=upsample,
                          dropout_prob=dropout_prob,
                          is3d=is3d)
        decoders.append(decoder)
    return nn.ModuleList(decoders)


class AbstractUpsampling(nn.Module):
    """
    Abstract class for upsampling. A given implementation should upsample a given input tensor
    (5D for 3d, 4D for 2d) using either interpolation or learned transposed convolution.

    `upsample` is a callable `(x, size) -> tensor`, where `size` is the spatial size of the
    encoder features.
    """

    def __init__(self, upsample):
        super(AbstractUpsampling, self).__init__()
        self.upsample = upsample

    def forward(self, encoder_features, x):
        # get the spatial dimensions of the output given the encoder_features
        output_size = encoder_features.size()[2:]
        # upsample the input and return
        return self.upsample(x, output_size)


class InterpolateUpsampling(AbstractUpsampling):
    """
    Args:
        mode (str): algorithm used for upsampling:
            'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
    """

    def __init__(self, mode='nearest'):
        upsample = partial(self._interpolate, mode=mode)
        super().__init__(upsample)

    @staticmethod
    def _interpolate(x, size, mode):
        return F.interpolate(x, size=size, mode=mode)


class TransposeConvUpsampling(AbstractUpsampling):
    """
    Args:
        in_channels (int): number of input channels for transposed conv
        out_channels (int): number of output channels for transposed conv
        kernel_size (int or tuple): size of the convolving kernel
        scale_factor (int or tuple): stride of the convolution
        is3d (bool): if True use ConvTranspose3d, otherwise use ConvTranspose2d
    """

    class Upsample(nn.Module):
        """
        Workaround the 'ValueError: requested an output size...' in the `_output_padding` method in
        transposed convolution. It performs transposed conv followed by the interpolation to the correct size if necessary.
        """

        def __init__(self, conv_transposed, is3d):
            super().__init__()
            self.conv_transposed = conv_transposed
            self.is3d = is3d

        def forward(self, x, size):
            x = self.conv_transposed(x)
            # interpolate only when the transposed conv did not already produce the target spatial size
            if x.shape[2:] != size:
                x = F.interpolate(x, size=size)
            return x

    def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=2, is3d=True):
        # make sure that the output size reverses the MaxPool3d from the corresponding encoder
        if is3d:
            conv_transposed = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size,
                                                 stride=scale_factor, padding=1, bias=False)
        else:
            conv_transposed = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
                                                 stride=scale_factor, padding=1, bias=False)
        upsample = self.Upsample(conv_transposed, is3d)
        super().__init__(upsample)


class NoUpsampling(AbstractUpsampling):
    """Identity 'upsampling': returns the input unchanged, ignoring the target size."""

    def __init__(self):
        super().__init__(self._no_upsampling)

    @staticmethod
    def _no_upsampling(x, size):
        return x