import functools

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import layers

from ..layers import BlockImages, SwapAxes, UnblockImages
from .block_gating import BlockGmlpLayer
from .grid_gating import GridGmlpLayer

Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
ConvT_up = functools.partial(
    layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same"
)
Conv_down = functools.partial(
    layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
)
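
# Usage note (illustrative, not part of the original module): the partials above
# fix kernel size, stride, and padding, so call sites only supply `filters`, e.g.
#   down = Conv_down(filters=64, use_bias=True, name="down_0")  # halves H and W
#   up = ConvT_up(filters=64, use_bias=True, name="up_0")       # doubles H and W
# (the names "down"/"up" here are hypothetical, chosen for the example)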

def ResidualSplitHeadMultiAxisGmlpLayer(
    block_size,
    grid_size,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    use_bias: bool = True,
    dropout_rate: float = 0.0,
    name: str = "residual_split_head_maxim",
):
    """The multi-axis gated MLP block."""

    def apply(x):
        shortcut = x
        n, h, w, num_channels = (
            K.int_shape(x)[0],
            K.int_shape(x)[1],
            K.int_shape(x)[2],
            K.int_shape(x)[3],
        )
        x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_in")(x)
        x = layers.Dense(
            int(num_channels) * input_proj_factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(x)
        x = tf.nn.gelu(x, approximate=True)
        u, v = tf.split(x, 2, axis=-1)
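
        # Split heads: `u` feeds the grid gMLP (global mixing across a fixed
        # grid) and `v` feeds the block gMLP (local mixing within fixed-size
        # blocks); the two halves run in parallel and are concatenated below.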
        # GridGMLPLayer
        u = GridGmlpLayer(
            grid_size=grid_size,
            factor=grid_gmlp_factor,
            use_bias=use_bias,
            dropout_rate=dropout_rate,
            name=f"{name}_GridGmlpLayer",
        )(u)

        # BlockGMLPLayer
        v = BlockGmlpLayer(
            block_size=block_size,
            factor=block_gmlp_factor,
            use_bias=use_bias,
            dropout_rate=dropout_rate,
            name=f"{name}_BlockGmlpLayer",
        )(v)

        x = tf.concat([u, v], axis=-1)
        x = layers.Dense(
            num_channels,
            use_bias=use_bias,
            name=f"{name}_out_project",
        )(x)
        x = layers.Dropout(dropout_rate)(x)
        x = x + shortcut
        return x

    return apply
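
# Example (a minimal sketch, not part of the original module; the shapes are
# assumptions for illustration). H and W must be divisible by grid_size and
# block_size respectively:
#   block = ResidualSplitHeadMultiAxisGmlpLayer(block_size=(8, 8), grid_size=(8, 8))
#   y = block(tf.ones((1, 64, 64, 32)))  # residual block, output (1, 64, 64, 32)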

def GetSpatialGatingWeights(
    features: int,
    block_size,
    grid_size,
    input_proj_factor: int = 2,
    dropout_rate: float = 0.0,
    use_bias: bool = True,
    name: str = "spatial_gating",
):
    """Get gating weights for cross-gating MLP block."""

    def apply(x):
        n, h, w, num_channels = (
            K.int_shape(x)[0],
            K.int_shape(x)[1],
            K.int_shape(x)[2],
            K.int_shape(x)[3],
        )

        # input projection
        x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_in")(x)
        x = layers.Dense(
            num_channels * input_proj_factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(x)
        x = tf.nn.gelu(x, approximate=True)
        u, v = tf.split(x, 2, axis=-1)

        # Get grid MLP weights
        gh, gw = grid_size
        fh, fw = h // gh, w // gw
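        # The image is cut into a fixed gh x gw grid of fh x fw patches (so H
        # and W must be divisible by grid_size); after the axis swap, the Dense
        # below mixes across the gh*gw grid positions, i.e. globally.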
        u = BlockImages()(u, patch_size=(fh, fw))
        dim_u = K.int_shape(u)[-3]
        u = SwapAxes()(u, -1, -3)
        u = layers.Dense(dim_u, use_bias=use_bias, name=f"{name}_Dense_0")(u)
        u = SwapAxes()(u, -1, -3)
        u = UnblockImages()(u, grid_size=(gh, gw), patch_size=(fh, fw))

        # Get Block MLP weights
        fh, fw = block_size
        gh, gw = h // fh, w // fw
        v = BlockImages()(v, patch_size=(fh, fw))
        dim_v = K.int_shape(v)[-2]
        v = SwapAxes()(v, -1, -2)
        v = layers.Dense(dim_v, use_bias=use_bias, name=f"{name}_Dense_1")(v)
        v = SwapAxes()(v, -1, -2)
        v = UnblockImages()(v, grid_size=(gh, gw), patch_size=(fh, fw))

        x = tf.concat([u, v], axis=-1)
        x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project")(x)
        x = layers.Dropout(dropout_rate)(x)
        return x

    return apply
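
# Example (a minimal sketch, not part of the original module; the shapes are
# assumptions). The gating weights keep the input's spatial shape:
#   gate = GetSpatialGatingWeights(features=32, block_size=(8, 8), grid_size=(8, 8))
#   g = gate(tf.ones((1, 64, 64, 32)))  # -> (1, 64, 64, 32), ready for x * g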

def CrossGatingBlock(
    features: int,
    block_size,
    grid_size,
    dropout_rate: float = 0.0,
    input_proj_factor: int = 2,
    upsample_y: bool = True,
    use_bias: bool = True,
    name: str = "cross_gating",
):
    """Cross-gating MLP block."""

    def apply(x, y):
        # Upscale Y signal, y is the gating signal.
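        # (ConvT_up uses kernel (2, 2) with stride (2, 2), so it exactly
        # doubles H and W, bringing a lower-resolution gating signal up to
        # x's resolution before the two are mixed.)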
        if upsample_y:
            y = ConvT_up(
                filters=features, use_bias=use_bias, name=f"{name}_ConvTranspose_0"
            )(y)

        x = Conv1x1(filters=features, use_bias=use_bias, name=f"{name}_Conv_0")(x)
        n, h, w, num_channels = (
            K.int_shape(x)[0],
            K.int_shape(x)[1],
            K.int_shape(x)[2],
            K.int_shape(x)[3],
        )
        y = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1")(y)

        shortcut_x = x
        shortcut_y = y

        # Get gating weights from X
        x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_x")(x)
        x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_x")(x)
        x = tf.nn.gelu(x, approximate=True)
        gx = GetSpatialGatingWeights(
            features=num_channels,
            block_size=block_size,
            grid_size=grid_size,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name=f"{name}_SplitHeadMultiAxisGating_x",
        )(x)

        # Get gating weights from Y
        y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_y")(y)
        y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_y")(y)
        y = tf.nn.gelu(y, approximate=True)
        gy = GetSpatialGatingWeights(
            features=num_channels,
            block_size=block_size,
            grid_size=grid_size,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name=f"{name}_SplitHeadMultiAxisGating_y",
        )(y)

        # Apply cross gating: X = X * GY, Y = Y * GX
        y = y * gx
        y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_y")(y)
        y = layers.Dropout(dropout_rate)(y)
        y = y + shortcut_y

        x = x * gy  # gating x using y
        x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_x")(x)
        x = layers.Dropout(dropout_rate)(x)
        x = x + y + shortcut_x  # get all aggregated signals
        return x, y

    return apply
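
# Example (a minimal sketch, not part of the original module; the shapes are
# assumptions). With upsample_y=True, y enters at half of x's resolution and
# is upsampled to match before cross gating:
#   cgb = CrossGatingBlock(features=32, block_size=(8, 8), grid_size=(8, 8))
#   x_out, y_out = cgb(tf.ones((1, 64, 64, 32)), tf.ones((1, 32, 32, 64)))
#   # both outputs: (1, 64, 64, 32)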