import functools

import tensorflow as tf
from tensorflow.keras import layers

from .others import MlpBlock

Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
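
# Usage sketch for the partials above (the filter count and name are
# illustrative assumptions): each partial fixes the kernel size and padding,
# so call sites only supply the remaining Conv2D arguments.
#     conv = Conv3x3(filters=16, use_bias=True, name="example_conv")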


def CALayer(
    num_channels: int,
    reduction: int = 4,
    use_bias: bool = True,
    name: str = "channel_attention",
):
    """Squeeze-and-excitation block for channel attention.

    ref: https://arxiv.org/abs/1709.01507
    """

    def apply(x):
        # 2D global average pooling
        y = layers.GlobalAvgPool2D(keepdims=True)(x)
        # Squeeze (in Squeeze-Excitation)
        y = Conv1x1(
            filters=num_channels // reduction, use_bias=use_bias, name=f"{name}_Conv_0"
        )(y)
        y = tf.nn.relu(y)
        # Excitation (in Squeeze-Excitation)
        y = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1")(y)
        y = tf.nn.sigmoid(y)
        return x * y

    return apply
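
# Usage sketch (batch and spatial sizes are illustrative assumptions): channel
# attention only rescales channels, so the output shape matches the input.
#     x = tf.random.normal((1, 64, 64, 32))
#     y = CALayer(num_channels=32, reduction=4)(x)
#     # y.shape == (1, 64, 64, 32)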


def RCAB(
    num_channels: int,
    reduction: int = 4,
    lrelu_slope: float = 0.2,
    use_bias: bool = True,
    name: str = "residual_ca",
):
    """Residual channel attention block.

    Contains LayerNorm, Conv, LeakyReLU, Conv, and an SE (channel attention) layer.
    """

    def apply(x):
        shortcut = x
        x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        x = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_conv1")(x)
        x = tf.nn.leaky_relu(x, alpha=lrelu_slope)
        x = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_conv2")(x)
        x = CALayer(
            num_channels=num_channels,
            reduction=reduction,
            use_bias=use_bias,
            name=f"{name}_channel_attention",
        )(x)
        return x + shortcut

    return apply
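
# Usage sketch (shapes assumed for illustration): the residual add requires
# num_channels to equal the channel dimension of the input.
#     x = tf.random.normal((1, 64, 64, 32))
#     y = RCAB(num_channels=32)(x)
#     # y.shape == x.shape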


def RDCAB(
    num_channels: int,
    reduction: int = 16,
    use_bias: bool = True,
    dropout_rate: float = 0.0,
    name: str = "rdcab",
):
    """Residual dense channel attention block. Used in bottlenecks."""

    def apply(x):
        y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        y = MlpBlock(
            mlp_dim=num_channels,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name=f"{name}_channel_mixing",
        )(y)
        y = CALayer(
            num_channels=num_channels,
            reduction=reduction,
            use_bias=use_bias,
            name=f"{name}_channel_attention",
        )(y)
        x = x + y
        return x

    return apply
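
# Usage sketch. This assumes MlpBlock (imported from .others) projects its
# output back to the input channel count, so the residual add is
# shape-compatible and RDCAB is shape-preserving:
#     x = tf.random.normal((1, 16, 16, 64))
#     y = RDCAB(num_channels=64, reduction=16)(x)
#     # y.shape == x.shape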


def SAM(
    num_channels: int,
    output_channels: int = 3,
    use_bias: bool = True,
    name: str = "sam",
):
    """Supervised attention module for multi-stage training.

    Introduced by MPRNet [CVPR2021]: https://github.com/swz30/MPRNet
    """

    def apply(x, x_image):
        """Apply the SAM module to the input and features.

        Args:
          x: the output features from UNet decoder with shape (h, w, c)
          x_image: the input image with shape (h, w, 3)

        Returns:
          A tuple of tensors (x1, image) where (x1) is the SAM features used for the
          next stage, and (image) is the output restored image at the current stage.
        """
        # Get features
        x1 = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)

        # Output restored image X_s
        if output_channels == 3:
            image = (
                Conv3x3(
                    filters=output_channels, use_bias=use_bias, name=f"{name}_Conv_1"
                )(x)
                + x_image
            )
        else:
            image = Conv3x3(
                filters=output_channels, use_bias=use_bias, name=f"{name}_Conv_1"
            )(x)

        # Get attention maps for features
        x2 = tf.nn.sigmoid(
            Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_2")(image)
        )

        # Get attended feature maps
        x1 = x1 * x2

        # Residual connection
        x1 = x1 + x
        return x1, image

    return apply
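
# Usage sketch (shapes are illustrative assumptions): SAM consumes decoder
# features plus the stage's input image, and returns the attended features for
# the next stage along with the restored image for the current stage.
#     feats = tf.random.normal((1, 64, 64, 32))
#     img = tf.random.normal((1, 64, 64, 3))
#     x1, restored = SAM(num_channels=32, output_channels=3)(feats, img)
#     # x1.shape == (1, 64, 64, 32); restored.shape == (1, 64, 64, 3)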