""" U-shaped DISCO Neural Operator """ from typing import List, Tuple import torch import torch.nn as nn from torch.nn import functional as F from torch_harmonics_local.convolution import ( EquidistantDiscreteContinuousConv2d as DISCO2d, ) class UDNO(nn.Module): """ U-shaped DISCO Neural Operator in PyTorch """ def __init__( self, in_chans: int, out_chans: int, radius_cutoff: float, chans: int = 32, num_pool_layers: int = 4, drop_prob: float = 0.0, in_shape: Tuple[int, int] = (320, 320), kernel_shape: Tuple[int, int] = (3, 4), ): """ Parameters ---------- in_chans : int Number of channels in the input to the U-Net model. out_chans : int Number of channels in the output to the U-Net model. radius_cutoff : float Control the effective radius of the DISCO kernel. Values are between 0.0 and 1.0. The radius_cutoff is represented as a proportion of the normalized input space, to ensure that kernels are resolution invaraint. chans : int, optional Number of output channels of the first DISCO layer. Default is 32. num_pool_layers : int, optional Number of down-sampling and up-sampling layers. Default is 4. drop_prob : float, optional Dropout probability. Default is 0.0. in_shape : Tuple[int, int] Shape of the input to the UDNO. This is required to dynamically compile DISCO kernels for resolution invariance. kernel_shape : Tuple[int, int], optional Shape of the DISCO kernel. Default is (3, 4). This corresponds to 3 rings and 4 anisotropic basis functions. Under the hood, each DISCO kernel has (3 - 1) * 4 + 1 = 9 parameters, equivalent to a standard 3x3 convolution kernel. Note: This is NOT kernel_size, as under the DISCO framework, kernels are dynamically compiled to support resolution invariance. """ super().__init__() assert len(in_shape) == 2, "Input shape must be 2D" self.in_chans = in_chans self.out_chans = out_chans self.chans = chans self.num_pool_layers = num_pool_layers self.drop_prob = drop_prob self.in_shape = in_shape self.kernel_shape = kernel_shape self.down_sample_layers = nn.ModuleList( [ DISCOBlock( in_chans, chans, radius_cutoff, drop_prob, in_shape, kernel_shape, ) ] ) ch = chans shape = (in_shape[0] // 2, in_shape[1] // 2) radius_cutoff = radius_cutoff * 2 for _ in range(num_pool_layers - 1): self.down_sample_layers.append( DISCOBlock( ch, ch * 2, radius_cutoff, drop_prob, in_shape=shape, kernel_shape=kernel_shape, ) ) ch *= 2 shape = (shape[0] // 2, shape[1] // 2) radius_cutoff *= 2 # test commit self.bottleneck = DISCOBlock( ch, ch * 2, radius_cutoff, drop_prob, in_shape=shape, kernel_shape=kernel_shape, ) self.up = nn.ModuleList() self.up_transpose = nn.ModuleList() for _ in range(num_pool_layers - 1): self.up_transpose.append( TransposeDISCOBlock( ch * 2, ch, radius_cutoff, in_shape=shape, kernel_shape=kernel_shape, ) ) shape = (shape[0] * 2, shape[1] * 2) radius_cutoff /= 2 self.up.append( DISCOBlock( ch * 2, ch, radius_cutoff, drop_prob, in_shape=shape, kernel_shape=kernel_shape, ) ) ch //= 2 self.up_transpose.append( TransposeDISCOBlock( ch * 2, ch, radius_cutoff, in_shape=shape, kernel_shape=kernel_shape, ) ) shape = (shape[0] * 2, shape[1] * 2) radius_cutoff /= 2 self.up.append( nn.Sequential( DISCOBlock( ch * 2, ch, radius_cutoff, drop_prob, in_shape=shape, kernel_shape=kernel_shape, ), nn.Conv2d( ch, self.out_chans, kernel_size=1, stride=1 ), # 1x1 conv is always res-invariant (pixel wise channel transformation) ) ) def forward(self, image: torch.Tensor) -> torch.Tensor: """ Parameters ---------- image : torch.Tensor Input 4D tensor of shape `(N, in_chans, H, W)`. 

        Returns
        -------
        torch.Tensor
            Output tensor of shape `(N, out_chans, H, W)`.
        """
        stack = []
        output = image

        # apply down-sampling layers
        for layer in self.down_sample_layers:
            output = layer(output)
            stack.append(output)
            output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)

        output = self.bottleneck(output)

        # apply up-sampling layers
        for transpose, disco in zip(self.up_transpose, self.up):
            downsample_layer = stack.pop()
            output = transpose(output)

            # reflect-pad on the right/bottom if needed to handle odd input
            # dimensions
            padding = [0, 0, 0, 0]
            if output.shape[-1] != downsample_layer.shape[-1]:
                padding[1] = 1  # padding right
            if output.shape[-2] != downsample_layer.shape[-2]:
                padding[3] = 1  # padding bottom
            if sum(padding) > 0:
                output = F.pad(output, padding, "reflect")

            output = torch.cat([output, downsample_layer], dim=1)
            output = disco(output)

        return output


class DISCOBlock(nn.Module):
    """
    A DISCO Block that consists of two DISCO layers, each followed by
    instance normalization, LeakyReLU activation, and dropout.
    """

    def __init__(
        self,
        in_chans: int,
        out_chans: int,
        radius_cutoff: float,
        drop_prob: float,
        in_shape: Tuple[int, int],
        kernel_shape: Tuple[int, int] = (3, 4),
    ):
        """
        Parameters
        ----------
        in_chans : int
            Number of channels in the input.
        out_chans : int
            Number of channels in the output.
        radius_cutoff : float
            Controls the effective radius of the DISCO kernel. Values are
            between 0.0 and 1.0. The radius_cutoff is expressed as a
            proportion of the normalized input space, to ensure that
            kernels are resolution invariant.
        drop_prob : float
            Dropout probability.
        in_shape : Tuple[int, int]
            Unbatched spatial 2D shape of the input to this block. Required
            to dynamically compile DISCO kernels for resolution invariance.
        kernel_shape : Tuple[int, int], optional
            Shape of the DISCO kernel. Default is (3, 4). This corresponds
            to 3 rings and 4 anisotropic basis functions. Under the hood,
            each DISCO kernel has (3 - 1) * 4 + 1 = 9 parameters,
            equivalent to a standard 3x3 convolution kernel. Note: This is
            NOT kernel_size, as under the DISCO framework, kernels are
            dynamically compiled to support resolution invariance.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.drop_prob = drop_prob

        self.layers = nn.Sequential(
            DISCO2d(
                in_chans,
                out_chans,
                kernel_shape=kernel_shape,
                in_shape=in_shape,
                bias=False,
                radius_cutoff=radius_cutoff,
                padding_mode="constant",
            ),
            nn.InstanceNorm2d(out_chans),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Dropout2d(drop_prob),
            DISCO2d(
                out_chans,
                out_chans,
                kernel_shape=kernel_shape,
                in_shape=in_shape,
                bias=False,
                radius_cutoff=radius_cutoff,
                padding_mode="constant",
            ),
            nn.InstanceNorm2d(out_chans),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Dropout2d(drop_prob),
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        image : torch.Tensor
            Input 4D tensor of shape `(N, in_chans, H, W)`.

        Returns
        -------
        torch.Tensor
            Output tensor of shape `(N, out_chans, H, W)`.
        """
        return self.layers(image)


class TransposeDISCOBlock(nn.Module):
    """
    A transpose DISCO Block that consists of an up-sampling layer followed
    by a DISCO layer, instance normalization, and LeakyReLU activation.
    """

    def __init__(
        self,
        in_chans: int,
        out_chans: int,
        radius_cutoff: float,
        in_shape: Tuple[int, int],
        kernel_shape: Tuple[int, int] = (3, 4),
    ):
        """
        Parameters
        ----------
        in_chans : int
            Number of channels in the input.
        out_chans : int
            Number of channels in the output.
        radius_cutoff : float
            Controls the effective radius of the DISCO kernel. Values are
            between 0.0 and 1.0. The radius_cutoff is expressed as a
            proportion of the normalized input space, to ensure that
            kernels are resolution invariant.
        in_shape : Tuple[int, int]
            Unbatched spatial 2D shape of the input to this block. Required
            to dynamically compile DISCO kernels for resolution invariance.
        kernel_shape : Tuple[int, int], optional
            Shape of the DISCO kernel. Default is (3, 4). This corresponds
            to 3 rings and 4 anisotropic basis functions. Under the hood,
            each DISCO kernel has (3 - 1) * 4 + 1 = 9 parameters,
            equivalent to a standard 3x3 convolution kernel. Note: This is
            NOT kernel_size, as under the DISCO framework, kernels are
            dynamically compiled to support resolution invariance.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans

        self.layers = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            # the DISCO layer sees the up-sampled grid, so its input shape
            # doubles and its proportional radius halves
            DISCO2d(
                in_chans,
                out_chans,
                kernel_shape=kernel_shape,
                in_shape=(2 * in_shape[0], 2 * in_shape[1]),
                bias=False,
                radius_cutoff=(radius_cutoff / 2),
                padding_mode="constant",
            ),
            nn.InstanceNorm2d(out_chans),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        image : torch.Tensor
            Input 4D tensor of shape `(N, in_chans, H, W)`.

        Returns
        -------
        torch.Tensor
            Output tensor of shape `(N, out_chans, H*2, W*2)`.
        """
        return self.layers(image)
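

if __name__ == "__main__":
    # Minimal usage sketch (an illustration added by the editor, not part
    # of the original module): build a small UDNO and run a forward pass on
    # random data. The radius_cutoff of 0.02 and the channel counts are
    # illustrative assumptions; the input spatial size must match
    # `in_shape` so the DISCO kernels are compiled for the correct grid.
    model = UDNO(
        in_chans=1,
        out_chans=1,
        radius_cutoff=0.02,
        chans=32,
        num_pool_layers=4,
        drop_prob=0.0,
        in_shape=(320, 320),
    )
    x = torch.randn(2, 1, 320, 320)  # (N, in_chans, H, W)
    y = model(x)
    print(y.shape)  # expected: torch.Size([2, 1, 320, 320])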