danhtran2mind commited on
Commit
ba81626
·
verified ·
1 Parent(s): e67a29c

Update models/autoencoder_gray2color.py

Browse files
Files changed (1) hide show
  1. models/autoencoder_gray2color.py +0 -10
models/autoencoder_gray2color.py CHANGED
@@ -3,10 +3,8 @@ from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, B
3
  from tensorflow.keras.models import Model
4
  from tensorflow.keras.optimizers import Adam
5
 
6
- # Set float32 policy
7
  tf.keras.mixed_precision.set_global_policy('float32')
8
 
9
- # Spatial Attention Layer
10
  class SpatialAttention(tf.keras.layers.Layer):
11
  def __init__(self, kernel_size=7, **kwargs):
12
  super(SpatialAttention, self).__init__(**kwargs)
@@ -25,16 +23,13 @@ class SpatialAttention(tf.keras.layers.Layer):
25
  config.update({'kernel_size': self.kernel_size})
26
  return config
27
 
28
- # Build Autoencoder
29
  def build_autoencoder(height, width):
30
  input_img = Input(shape=(height, width, 1))
31
- # Encoder
32
  x = Conv2D(96, (3, 3), activation='relu', padding='same')(input_img)
33
  x = BatchNormalization()(x)
34
  x = SpatialAttention()(x)
35
  x = MaxPooling2D((2, 2), padding='same')(x)
36
 
37
- # Residual Block 1
38
  residual = Conv2D(192, (1, 1), padding='same')(x)
39
  x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
40
  x = BatchNormalization()(x)
@@ -44,7 +39,6 @@ def build_autoencoder(height, width):
44
  x = SpatialAttention()(x)
45
  x = MaxPooling2D((2, 2), padding='same')(x)
46
 
47
- # Residual Block 2
48
  residual = Conv2D(384, (1, 1), padding='same')(x)
49
  x = Conv2D(384, (3, 3), activation='relu', padding='same')(x)
50
  x = BatchNormalization()(x)
@@ -54,13 +48,11 @@ def build_autoencoder(height, width):
54
  x = SpatialAttention()(x)
55
  encoded = MaxPooling2D((2, 2), padding='same')(x)
56
 
57
- # Decoder
58
  x = Conv2D(384, (3, 3), activation='relu', padding='same')(encoded)
59
  x = BatchNormalization()(x)
60
  x = SpatialAttention()(x)
61
  x = UpSampling2D((2, 2))(x)
62
 
63
- # Residual Block 3
64
  residual = Conv2D(192, (1, 1), padding='same')(x)
65
  x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
66
  x = BatchNormalization()(x)
@@ -79,9 +71,7 @@ def build_autoencoder(height, width):
79
  return Model(input_img, decoded)
80
 
81
  if __name__ == "__main__":
82
- # Define constants
83
  HEIGHT, WIDTH = 512, 512
84
- # Compile model
85
  autoencoder = build_autoencoder(HEIGHT, WIDTH)
86
  autoencoder.summary()
87
  autoencoder.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())
 
3
  from tensorflow.keras.models import Model
4
  from tensorflow.keras.optimizers import Adam
5
 
 
6
  tf.keras.mixed_precision.set_global_policy('float32')
7
 
 
8
  class SpatialAttention(tf.keras.layers.Layer):
9
  def __init__(self, kernel_size=7, **kwargs):
10
  super(SpatialAttention, self).__init__(**kwargs)
 
23
  config.update({'kernel_size': self.kernel_size})
24
  return config
25
 
 
26
  def build_autoencoder(height, width):
27
  input_img = Input(shape=(height, width, 1))
 
28
  x = Conv2D(96, (3, 3), activation='relu', padding='same')(input_img)
29
  x = BatchNormalization()(x)
30
  x = SpatialAttention()(x)
31
  x = MaxPooling2D((2, 2), padding='same')(x)
32
 
 
33
  residual = Conv2D(192, (1, 1), padding='same')(x)
34
  x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
35
  x = BatchNormalization()(x)
 
39
  x = SpatialAttention()(x)
40
  x = MaxPooling2D((2, 2), padding='same')(x)
41
 
 
42
  residual = Conv2D(384, (1, 1), padding='same')(x)
43
  x = Conv2D(384, (3, 3), activation='relu', padding='same')(x)
44
  x = BatchNormalization()(x)
 
48
  x = SpatialAttention()(x)
49
  encoded = MaxPooling2D((2, 2), padding='same')(x)
50
 
 
51
  x = Conv2D(384, (3, 3), activation='relu', padding='same')(encoded)
52
  x = BatchNormalization()(x)
53
  x = SpatialAttention()(x)
54
  x = UpSampling2D((2, 2))(x)
55
 
 
56
  residual = Conv2D(192, (1, 1), padding='same')(x)
57
  x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
58
  x = BatchNormalization()(x)
 
71
  return Model(input_img, decoded)
72
 
73
  if __name__ == "__main__":
 
74
  HEIGHT, WIDTH = 512, 512
 
75
  autoencoder = build_autoencoder(HEIGHT, WIDTH)
76
  autoencoder.summary()
77
  autoencoder.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())