Update models/transformer_gray2color.py
models/transformer_gray2color.py
CHANGED
@@ -3,20 +3,16 @@ from tensorflow.keras.layers import Input, Dense, LayerNormalization, Dropout, M
 from tensorflow.keras.models import Model
 from tensorflow.keras.optimizers import Adam
 
-# Set float32 policy
 tf.keras.mixed_precision.set_global_policy('float32')
 
-# Define Transformer model
 def transformer_model(input_shape=(1024, 1024, 1), patch_size=8, d_model=32, num_heads=4, ff_dim=64, num_layers=2, dropout_rate=0.1):
     HEIGHT, WIDTH, _ = input_shape
     num_patches = (HEIGHT // patch_size) * (WIDTH // patch_size)
 
     inputs = Input(shape=input_shape)
-    # Patch extraction
     x = Conv2D(d_model, (patch_size, patch_size), strides=(patch_size, patch_size), padding='valid')(inputs)
     x = Reshape((num_patches, d_model))(x)
 
-    # Transformer layers
     for _ in range(num_layers):
         attn_output = MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)(x, x)
         attn_output = Dropout(dropout_rate)(attn_output)
@@ -28,7 +24,6 @@ def transformer_model(input_shape=(1024, 1024, 1), patch_size=8, d_model=32, num
         x = Add()([x, ff_output])
         x = LayerNormalization(epsilon=1e-6)(x)
 
-    # Decoder: Reconstruct image
     x = Dense(2)(x)
     x = Reshape((HEIGHT // patch_size, WIDTH // patch_size, 2))(x)
     x = UpSampling2D(size=(patch_size, patch_size), interpolation='bilinear')(x)
@@ -37,9 +32,7 @@ def transformer_model(input_shape=(1024, 1024, 1), patch_size=8, d_model=32, num
     return Model(inputs, outputs)
 
 if __name__ == "__main__":
-    # Define constants
     HEIGHT, WIDTH = 1024, 1024
-    # Instantiate and compile the model
     model = transformer_model(input_shape=(HEIGHT, WIDTH, 1), patch_size=8, d_model=32, num_heads=4, ff_dim=64, num_layers=2)
     model.summary()
     model.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())
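For context, a minimal sketch of how the updated module could be exercised is shown below. It assumes the file is importable as models.transformer_gray2color and that the model maps a single-channel grayscale input to a two-channel (chroma) output; the final output layers fall outside the hunks shown above, so the exact output shape is an assumption, and the dummy batch is illustrative only.

# Minimal smoke-test sketch (illustrative; import path and output channel
# count are assumptions, not confirmed by the diff above).
import numpy as np
import tensorflow as tf

from models.transformer_gray2color import transformer_model

model = transformer_model(input_shape=(1024, 1024, 1), patch_size=8,
                          d_model=32, num_heads=4, ff_dim=64, num_layers=2)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=7e-5),
              loss=tf.keras.losses.MeanSquaredError())

# Feed one dummy grayscale image; a gray-to-color model of this shape is
# expected to return a (1, 1024, 1024, 2) chroma prediction.
dummy_gray = np.random.rand(1, 1024, 1024, 1).astype("float32")
print(model.predict(dummy_gray).shape)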