Spaces:
Sleeping
Sleeping
| import tensorflow as tf | |
| from tensorflow.keras import backend as K | |
| from tensorflow.keras import layers | |
| from ..layers import BlockImages, SwapAxes, UnblockImages | |
def GridGatingUnit(use_bias: bool = True, name: str = "grid_gating_unit"):
    """Build a gMLP-style spatial gating unit that mixes along axis -3.

    The returned callable splits its input into two equal halves ``u`` and
    ``v`` along the last axis, layer-normalizes ``v``, projects ``v`` with a
    ``Dense`` layer whose width equals the static size of axis -3 of the
    input, and returns ``u * (v + 1.0)``.

    The axis being mixed is the third-from-last one: the code reads
    ``K.int_shape(x)[-3]`` and exchanges axes ``-1`` and ``-3`` around the
    projection. To gate along a different axis, move it to position -3
    before calling the unit.

    Args:
        use_bias: Whether the ``Dense`` projection uses a bias term.
        name: Prefix for the names of the layers created by the unit.

    Returns:
        A function ``apply(x)`` producing the gated tensor. Because the
        input is split in two along the last axis, the result has half as
        many channels as ``x``.

    Raises:
        ValueError: Raised by ``apply`` when axis -3 of the input has no
            statically known size, since that size is needed as the width
            of the ``Dense`` projection.
    """

    def apply(x):
        # Split the channels evenly: `u` is the gated half, `v` the gate.
        u, v = tf.split(x, 2, axis=-1)
        v = layers.LayerNormalization(
            epsilon=1e-06, name=f"{name}_intermediate_layernorm"
        )(v)

        # Static size of the axis that is mixed; used as the Dense width.
        n = K.int_shape(x)[-3]
        if n is None:
            raise ValueError(
                f"{name}: axis -3 of the input must have a statically known "
                f"size, got shape {K.int_shape(x)}."
            )

        # Dense acts on the last axis, so axis -3 is moved there for the
        # projection and moved back afterwards.
        # NOTE(review): SwapAxes is a project layer; assumed to exchange the
        # two given axes - confirm against its definition.
        v = SwapAxes()(v, -1, -3)
        v = layers.Dense(n, use_bias=use_bias, name=f"{name}_Dense_0")(v)
        v = SwapAxes()(v, -1, -3)

        # With the +1.0 offset, a zero projection leaves `u` unchanged.
        return u * (v + 1.0)

    return apply
def GridGmlpLayer(
    grid_size,
    use_bias: bool = True,
    factor: int = 2,
    dropout_rate: float = 0.0,
    name: str = "grid_gmlp",
):
    """Build a grid gMLP layer that performs global mixing of tokens.

    The returned callable reads the static height, width and channel count
    from axes 1, 2 and 3 of its input, derives a patch size of
    ``(h // gh, w // gw)`` from ``grid_size = (gh, gw)``, and then:

    1. blocks the image with ``BlockImages``,
    2. applies LayerNorm -> Dense(``num_channels * factor``) -> GELU
       (tanh approximation) -> ``GridGatingUnit`` -> Dense(``num_channels``)
       -> Dropout,
    3. adds the result to the blocked input (residual connection),
    4. restores the image layout with ``UnblockImages``.

    Args:
        grid_size: Two-element iterable ``(gh, gw)`` giving the grid
            dimensions used to derive the patch size.
        use_bias: Whether the ``Dense`` layers (including the one inside
            the gating unit) use a bias term.
        factor: Channel expansion factor of the input projection. The
            gating unit splits the expanded channels in two, so
            ``num_channels * factor`` must be even.
        dropout_rate: Rate passed to the ``Dropout`` layer.
        name: Prefix for the names of the layers created by this layer.

    Returns:
        A function ``apply(x)`` that maps an input tensor to the mixed
        tensor.

    Raises:
        ValueError: Raised by ``apply`` when the height, width or channel
            dimension of the input is not statically known; these sizes
            are needed to compute the patch size and the Dense widths.
    """

    def apply(x):
        # Read the static shape once; the batch dimension is not needed.
        static_shape = K.int_shape(x)
        h, w, num_channels = static_shape[1], static_shape[2], static_shape[3]
        if h is None or w is None or num_channels is None:
            raise ValueError(
                f"{name}: height, width and channels of the input must be "
                f"statically known, got shape {static_shape}."
            )

        gh, gw = grid_size
        # Patch size that tiles the image into a gh x gw grid.
        fh, fw = h // gh, w // gw

        # NOTE(review): BlockImages is a project layer; presumably it lays
        # the grid cells out on axis -3, which is the axis GridGatingUnit
        # mixes - confirm against its definition.
        x = BlockImages()(x, patch_size=(fh, fw))

        # gMLP1: Global (grid) mixing part, provides global grid communication.
        y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        y = layers.Dense(
            num_channels * factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(y)
        y = tf.nn.gelu(y, approximate=True)
        y = GridGatingUnit(use_bias=use_bias, name=f"{name}_GridGatingUnit")(y)
        # Project back to the input channel count so the residual add works.
        y = layers.Dense(
            num_channels,
            use_bias=use_bias,
            name=f"{name}_out_project",
        )(y)
        y = layers.Dropout(dropout_rate)(y)
        x = x + y

        # NOTE(review): UnblockImages is a project layer; presumably the
        # inverse of BlockImages for the same grid/patch sizes - confirm.
        x = UnblockImages()(x, grid_size=(gh, gw), patch_size=(fh, fw))
        return x

    return apply