danhtran2mind committed
Commit e67a29c · verified · 1 Parent(s): 280808d

Update models/transformer_gray2color.py

Files changed (1):
  models/transformer_gray2color.py  +0 -7
models/transformer_gray2color.py CHANGED
@@ -3,20 +3,16 @@ from tensorflow.keras.layers import Input, Dense, LayerNormalization, Dropout, M
 from tensorflow.keras.models import Model
 from tensorflow.keras.optimizers import Adam
 
-# Set float32 policy
 tf.keras.mixed_precision.set_global_policy('float32')
 
-# Define Transformer model
 def transformer_model(input_shape=(1024, 1024, 1), patch_size=8, d_model=32, num_heads=4, ff_dim=64, num_layers=2, dropout_rate=0.1):
     HEIGHT, WIDTH, _ = input_shape
     num_patches = (HEIGHT // patch_size) * (WIDTH // patch_size)
 
     inputs = Input(shape=input_shape)
-    # Patch extraction
     x = Conv2D(d_model, (patch_size, patch_size), strides=(patch_size, patch_size), padding='valid')(inputs)
     x = Reshape((num_patches, d_model))(x)
 
-    # Transformer layers
     for _ in range(num_layers):
         attn_output = MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)(x, x)
         attn_output = Dropout(dropout_rate)(attn_output)
@@ -28,7 +24,6 @@ def transformer_model(input_shape=(1024, 1024, 1), patch_size=8, d_model=32, num
     x = Add()([x, ff_output])
     x = LayerNormalization(epsilon=1e-6)(x)
 
-    # Decoder: Reconstruct image
     x = Dense(2)(x)
     x = Reshape((HEIGHT // patch_size, WIDTH // patch_size, 2))(x)
     x = UpSampling2D(size=(patch_size, patch_size), interpolation='bilinear')(x)
@@ -37,9 +32,7 @@ def transformer_model(input_shape=(1024, 1024, 1), patch_size=8, d_model=32, num
     return Model(inputs, outputs)
 
 if __name__ == "__main__":
-    # Define constants
     HEIGHT, WIDTH = 1024, 1024
-    # Instantiate and compile the model
     model = transformer_model(input_shape=(HEIGHT, WIDTH, 1), patch_size=8, d_model=32, num_heads=4, ff_dim=64, num_layers=2)
     model.summary()
     model.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())
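For context (not part of the diff): the strided Conv2D acts as a patch embedding, so a 1024x1024 input with patch_size=8 yields a 128x128 grid, i.e. 16,384 tokens of width d_model=32, and each MultiHeadAttention layer attends over that full sequence. A minimal standalone shape check under those numbers:

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Reshape

patch_size, d_model = 8, 32
x = tf.zeros((1, 1024, 1024, 1))                         # one grayscale image
x = Conv2D(d_model, patch_size, strides=patch_size, padding='valid')(x)  # (1, 128, 128, 32): one vector per 8x8 patch
x = Reshape((-1, d_model))(x)                            # (1, 16384, 32): token sequence fed to attention
print(x.shape)                                           # (1, 16384, 32)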
 
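The 1-channel input and 2-channel output suggest the usual Lab colorization setup (predict the a/b channels from the L channel), though the commit itself does not state the normalization. A hedged end-to-end inference sketch under that assumption, using scikit-image for the color conversion; the scaling factors and the import path are assumptions, not confirmed by this repo:

import numpy as np
from skimage import color  # assumption: scikit-image is available
from models.transformer_gray2color import transformer_model  # assumed import path

model = transformer_model(input_shape=(1024, 1024, 1))
# model.load_weights(...)  # checkpoint path omitted; not part of this commit

L = np.random.rand(1, 1024, 1024, 1).astype("float32")   # stand-in L channel in [0, 1]
ab = model.predict(L)[0]                                 # (1024, 1024, 2) predicted a/b, assumed in [-1, 1]

# Rescale to Lab ranges (L: [0, 100], a/b: roughly [-128, 127]) and convert to RGB.
lab = np.concatenate([L[0] * 100.0, ab * 128.0], axis=-1)
rgb = color.lab2rgb(lab)                                 # float RGB in [0, 1]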