Spaces:
Sleeping
Sleeping
| import functools | |
| from tensorflow.keras import layers | |
| from .attentions import RDCAB | |
| from .misc_gating import ResidualSplitHeadMultiAxisGmlpLayer | |
# Factory for 2-D convolutions with a fixed 1x1 kernel and "same" padding;
# the remaining Conv2D arguments (filters, use_bias, name, ...) are supplied
# by each call site.
Conv1x1 = functools.partial(
    layers.Conv2D,
    kernel_size=(1, 1),
    padding="same",
)
def BottleneckBlock(
    features: int,
    block_size,
    grid_size,
    num_groups: int = 1,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    dropout_rate: float = 0.0,
    use_bias: bool = True,
    name: str = "bottleneck_block",
):
    """Build the bottleneck stage made of multi-axis gMLP layers and RDCABs.

    The returned callable first projects its input with a 1x1 convolution to
    ``features`` channels, then runs ``num_groups`` rounds of
    ``ResidualSplitHeadMultiAxisGmlpLayer`` followed by ``RDCAB``, and finally
    adds the projected input back as a long skip connection.

    Args:
        features: Number of filters of the input projection; also passed to
            each RDCAB as ``num_channels``.
        block_size: Forwarded unchanged to the gMLP layer.
        grid_size: Forwarded unchanged to the gMLP layer.
        num_groups: How many (gMLP layer, RDCAB) pairs are stacked.
        block_gmlp_factor: Forwarded unchanged to the gMLP layer.
        grid_gmlp_factor: Forwarded unchanged to the gMLP layer.
        input_proj_factor: Forwarded unchanged to the gMLP layer.
        channels_reduction: Passed to each RDCAB as ``reduction``.
        dropout_rate: Forwarded unchanged to the gMLP layer.
        use_bias: Shared bias switch for every layer created here.
        name: Prefix used to build the name of every layer created here.

    Returns:
        A function ``apply(x)`` that builds the layers and returns the output
        tensor for input ``x``.
    """

    def _mixing_group(tensor, index):
        """Apply the ``index``-th gMLP layer and then its RDCAB to ``tensor``."""
        gmlp_layer = ResidualSplitHeadMultiAxisGmlpLayer(
            grid_size=grid_size,
            block_size=block_size,
            grid_gmlp_factor=grid_gmlp_factor,
            block_gmlp_factor=block_gmlp_factor,
            input_proj_factor=input_proj_factor,
            use_bias=use_bias,
            dropout_rate=dropout_rate,
            name=f"{name}_SplitHeadMultiAxisGmlpLayer_{index}",
        )
        tensor = gmlp_layer(tensor)

        # Channel-mixing part, which provides within-patch communication.
        channel_attention = RDCAB(
            num_channels=features,
            reduction=channels_reduction,
            use_bias=use_bias,
            name=f"{name}_channel_attention_block_1_{index}",
        )
        return channel_attention(tensor)

    def apply(x):
        # Input projection; its output doubles as the long-skip operand.
        input_proj = Conv1x1(
            filters=features,
            use_bias=use_bias,
            name=f"{name}_input_proj",
        )
        projected = input_proj(x)

        mixed = projected
        for group_index in range(num_groups):
            mixed = _mixing_group(mixed, group_index)

        # Long skip connection around all groups.
        return mixed + projected

    return apply