Truptidand committed on
Commit
de135fc
·
1 Parent(s): 7976c47

Upload GAN.ipynb

Browse files
Files changed (1) hide show
  1. GAN.ipynb +1514 -0
GAN.ipynb ADDED
@@ -0,0 +1,1514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "a3677b66",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import numpy as np\n",
11
+ "import pandas as pd\n",
12
+ "import matplotlib.pyplot as plt\n",
13
+ "import seaborn as sns\n",
14
+ "import os\n",
15
+ "import pickle\n",
16
+ "import time\n",
17
+ "import random"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": 8,
23
+ "id": "76ece7f8",
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": [
27
+ "import PIL\n",
28
+ "from PIL import Image\n",
29
+ "import keras.backend as K\n",
30
+ "import tensorflow as tf\n",
31
+ "from tensorflow import keras\n",
32
+ "from keras.optimizers import Adam\n",
33
+ "from keras.models import Sequential\n",
34
+ "from keras import layers,Model,Input\n",
35
+ "from keras.layers import Lambda,Reshape,UpSampling2D,ReLU,add,ZeroPadding2D\n",
36
+ "from keras.layers import Activation,BatchNormalization,Concatenate,concatenate\n",
37
+ "from keras.layers import Dense,Conv2D,Flatten,Dropout,LeakyReLU\n",
38
+ "from keras.preprocessing.image import ImageDataGenerator"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "markdown",
43
+ "id": "b8980cd5",
44
+ "metadata": {},
45
+ "source": [
46
+ "### Conditioning Augmentation Network"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 3,
52
+ "id": "d3027cda",
53
+ "metadata": {},
54
+ "outputs": [],
55
+ "source": [
56
+ "# conditioned by the text.\n",
57
+ "def conditioning_augmentation(x):\n",
58
+ " \"\"\"The mean_logsigma passed as argument is converted into the text conditioning variable.\n",
59
+ "\n",
60
+ " Args:\n",
61
+ " x: The output of the text embedding passed through a FC layer with LeakyReLU non-linearity.\n",
62
+ "\n",
63
+ " Returns:\n",
64
+ " c: The text conditioning variable after computation.\n",
65
+ " \"\"\"\n",
66
+ " mean = x[:, :128]\n",
67
+ " log_sigma = x[:, 128:]\n",
68
+ "\n",
69
+ " stddev = tf.math.exp(log_sigma)\n",
70
+ " epsilon = K.random_normal(shape=K.constant((mean.shape[1], ), dtype='int32'))\n",
71
+ " c = mean + stddev * epsilon\n",
72
+ " return c\n",
73
+ "\n",
74
+ "def build_ca_network():\n",
75
+ " \"\"\"Builds the conditioning augmentation network.\n",
76
+ " \"\"\"\n",
77
+ " input_layer1 = Input(shape=(1024,)) #size of the vocabulary in the text data\n",
78
+ " mls = Dense(256)(input_layer1)\n",
79
+ " mls = LeakyReLU(alpha=0.2)(mls)\n",
80
+ " ca = Lambda(conditioning_augmentation)(mls)\n",
81
+ " return Model(inputs=[input_layer1], outputs=[ca]) "
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "markdown",
86
+ "id": "87340e8b",
87
+ "metadata": {},
88
+ "source": [
89
+ "### Stage 1 Generator Network"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": 4,
95
+ "id": "c430524d",
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "def UpSamplingBlock(x, num_kernels):\n",
100
+ " \"\"\"An Upsample block with Upsampling2D, Conv2D, BatchNormalization and a ReLU activation.\n",
101
+ "\n",
102
+ " Args:\n",
103
+ " x: The preceding layer as input.\n",
104
+ " num_kernels: Number of kernels for the Conv2D layer.\n",
105
+ "\n",
106
+ " Returns:\n",
107
+ " x: The final activation layer after the Upsampling block.\n",
108
+ " \"\"\"\n",
109
+ " x = UpSampling2D(size=(2,2))(x)\n",
110
+ " x = Conv2D(num_kernels, kernel_size=(3,3), padding='same', strides=1, use_bias=False,\n",
111
+ " kernel_initializer='he_uniform')(x)\n",
112
+ " x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x) #prevent from mode collapse\n",
113
+ " x = ReLU()(x)\n",
114
+ " return x\n",
115
+ "\n",
116
+ "\n",
117
+ "def build_stage1_generator():\n",
118
+ "\n",
119
+ " input_layer1 = Input(shape=(1024,))\n",
120
+ " ca = Dense(256)(input_layer1)\n",
121
+ " ca = LeakyReLU(alpha=0.2)(ca)\n",
122
+ "\n",
123
+ " # Obtain the conditioned text\n",
124
+ " c = Lambda(conditioning_augmentation)(ca)\n",
125
+ "\n",
126
+ " input_layer2 = Input(shape=(100,))\n",
127
+ " concat = Concatenate(axis=1)([c, input_layer2]) \n",
128
+ "\n",
129
+ " x = Dense(16384, use_bias=False)(concat) \n",
130
+ " x = ReLU()(x)\n",
131
+ " x = Reshape((4, 4, 1024), input_shape=(16384,))(x)\n",
132
+ "\n",
133
+ " x = UpSamplingBlock(x, 512) \n",
134
+ " x = UpSamplingBlock(x, 256)\n",
135
+ " x = UpSamplingBlock(x, 128)\n",
136
+ " x = UpSamplingBlock(x, 64) # upsampled our image to 64*64*3 \n",
137
+ "\n",
138
+ " x = Conv2D(3, kernel_size=3, padding='same', strides=1, use_bias=False,\n",
139
+ " kernel_initializer='he_uniform')(x)\n",
140
+ " x = Activation('tanh')(x)\n",
141
+ "\n",
142
+ " stage1_gen = Model(inputs=[input_layer1, input_layer2], outputs=[x, ca]) \n",
143
+ " return stage1_gen"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": 5,
149
+ "id": "0febcb4f",
150
+ "metadata": {},
151
+ "outputs": [
152
+ {
153
+ "name": "stdout",
154
+ "output_type": "stream",
155
+ "text": [
156
+ "Model: \"model\"\n",
157
+ "__________________________________________________________________________________________________\n",
158
+ " Layer (type) Output Shape Param # Connected to \n",
159
+ "==================================================================================================\n",
160
+ " input_1 (InputLayer) [(None, 1024)] 0 [] \n",
161
+ " \n",
162
+ " dense (Dense) (None, 256) 262400 ['input_1[0][0]'] \n",
163
+ " \n",
164
+ " leaky_re_lu (LeakyReLU) (None, 256) 0 ['dense[0][0]'] \n",
165
+ " \n",
166
+ " lambda (Lambda) (None, 128) 0 ['leaky_re_lu[0][0]'] \n",
167
+ " \n",
168
+ " input_2 (InputLayer) [(None, 100)] 0 [] \n",
169
+ " \n",
170
+ " concatenate (Concatenate) (None, 228) 0 ['lambda[0][0]', \n",
171
+ " 'input_2[0][0]'] \n",
172
+ " \n",
173
+ " dense_1 (Dense) (None, 16384) 3735552 ['concatenate[0][0]'] \n",
174
+ " \n",
175
+ " re_lu (ReLU) (None, 16384) 0 ['dense_1[0][0]'] \n",
176
+ " \n",
177
+ " reshape (Reshape) (None, 4, 4, 1024) 0 ['re_lu[0][0]'] \n",
178
+ " \n",
179
+ " up_sampling2d (UpSampling2D) (None, 8, 8, 1024) 0 ['reshape[0][0]'] \n",
180
+ " \n",
181
+ " conv2d (Conv2D) (None, 8, 8, 512) 4718592 ['up_sampling2d[0][0]'] \n",
182
+ " \n",
183
+ " batch_normalization (BatchNorm (None, 8, 8, 512) 2048 ['conv2d[0][0]'] \n",
184
+ " alization) \n",
185
+ " \n",
186
+ " re_lu_1 (ReLU) (None, 8, 8, 512) 0 ['batch_normalization[0][0]'] \n",
187
+ " \n",
188
+ " up_sampling2d_1 (UpSampling2D) (None, 16, 16, 512) 0 ['re_lu_1[0][0]'] \n",
189
+ " \n",
190
+ " conv2d_1 (Conv2D) (None, 16, 16, 256) 1179648 ['up_sampling2d_1[0][0]'] \n",
191
+ " \n",
192
+ " batch_normalization_1 (BatchNo (None, 16, 16, 256) 1024 ['conv2d_1[0][0]'] \n",
193
+ " rmalization) \n",
194
+ " \n",
195
+ " re_lu_2 (ReLU) (None, 16, 16, 256) 0 ['batch_normalization_1[0][0]'] \n",
196
+ " \n",
197
+ " up_sampling2d_2 (UpSampling2D) (None, 32, 32, 256) 0 ['re_lu_2[0][0]'] \n",
198
+ " \n",
199
+ " conv2d_2 (Conv2D) (None, 32, 32, 128) 294912 ['up_sampling2d_2[0][0]'] \n",
200
+ " \n",
201
+ " batch_normalization_2 (BatchNo (None, 32, 32, 128) 512 ['conv2d_2[0][0]'] \n",
202
+ " rmalization) \n",
203
+ " \n",
204
+ " re_lu_3 (ReLU) (None, 32, 32, 128) 0 ['batch_normalization_2[0][0]'] \n",
205
+ " \n",
206
+ " up_sampling2d_3 (UpSampling2D) (None, 64, 64, 128) 0 ['re_lu_3[0][0]'] \n",
207
+ " \n",
208
+ " conv2d_3 (Conv2D) (None, 64, 64, 64) 73728 ['up_sampling2d_3[0][0]'] \n",
209
+ " \n",
210
+ " batch_normalization_3 (BatchNo (None, 64, 64, 64) 256 ['conv2d_3[0][0]'] \n",
211
+ " rmalization) \n",
212
+ " \n",
213
+ " re_lu_4 (ReLU) (None, 64, 64, 64) 0 ['batch_normalization_3[0][0]'] \n",
214
+ " \n",
215
+ " conv2d_4 (Conv2D) (None, 64, 64, 3) 1728 ['re_lu_4[0][0]'] \n",
216
+ " \n",
217
+ " activation (Activation) (None, 64, 64, 3) 0 ['conv2d_4[0][0]'] \n",
218
+ " \n",
219
+ "==================================================================================================\n",
220
+ "Total params: 10,270,400\n",
221
+ "Trainable params: 10,268,480\n",
222
+ "Non-trainable params: 1,920\n",
223
+ "__________________________________________________________________________________________________\n"
224
+ ]
225
+ }
226
+ ],
227
+ "source": [
228
+ "generator = build_stage1_generator()\n",
229
+ "generator.summary()"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "markdown",
234
+ "id": "a14d9d1c",
235
+ "metadata": {},
236
+ "source": [
237
+ "### Stage 1 Discriminator Network"
238
+ ]
239
+ },
240
+ {
241
+ "cell_type": "code",
242
+ "execution_count": 9,
243
+ "id": "32b436ac",
244
+ "metadata": {},
245
+ "outputs": [],
246
+ "source": [
247
+ "def ConvBlock(x, num_kernels, kernel_size=(4,4), strides=2, activation=True):\n",
248
+ " \"\"\"A ConvBlock with a Conv2D, BatchNormalization and LeakyReLU activation.\n",
249
+ "\n",
250
+ " Args:\n",
251
+ " x: The preceding layer as input.\n",
252
+ " num_kernels: Number of kernels for the Conv2D layer.\n",
253
+ "\n",
254
+ " Returns:\n",
255
+ " x: The final activation layer after the ConvBlock block.\n",
256
+ " \"\"\"\n",
257
+ " x = Conv2D(num_kernels, kernel_size=kernel_size, padding='same', strides=strides, use_bias=False,\n",
258
+ " kernel_initializer='he_uniform')(x)\n",
259
+ " x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)\n",
260
+ " \n",
261
+ " if activation:\n",
262
+ " x = LeakyReLU(alpha=0.2)(x)\n",
263
+ " return x\n",
264
+ "\n",
265
+ "\n",
266
+ "def build_embedding_compressor():\n",
267
+ " \"\"\"Build embedding compressor model\n",
268
+ " \"\"\"\n",
269
+ " input_layer1 = Input(shape=(1024,)) \n",
270
+ " x = Dense(128)(input_layer1)\n",
271
+ " x = ReLU()(x)\n",
272
+ "\n",
273
+ " model = Model(inputs=[input_layer1], outputs=[x])\n",
274
+ " return model\n",
275
+ "\n",
276
+ "# The discriminator takes two inputs: a 64x64 image (real or generated) and the compressed, spatially replicated text embedding.\n",
277
+ "def build_stage1_discriminator():\n",
278
+ " \"\"\"Builds the Stage 1 Discriminator that uses the 64x64 resolution images from the generator\n",
279
+ " and the compressed and spatially replicated embedding.\n",
280
+ "\n",
281
+ " Returns:\n",
282
+ " Stage 1 Discriminator Model for StackGAN.\n",
283
+ " \"\"\"\n",
284
+ " input_layer1 = Input(shape=(64, 64, 3)) \n",
285
+ "\n",
286
+ " x = Conv2D(64, kernel_size=(4,4), strides=2, padding='same', use_bias=False,\n",
287
+ " kernel_initializer='he_uniform')(input_layer1)\n",
288
+ " x = LeakyReLU(alpha=0.2)(x)\n",
289
+ "\n",
290
+ " x = ConvBlock(x, 128)\n",
291
+ " x = ConvBlock(x, 256)\n",
292
+ " x = ConvBlock(x, 512)\n",
293
+ "\n",
294
+ " # Obtain the compressed and spatially replicated text embedding\n",
295
+ " input_layer2 = Input(shape=(4, 4, 128)) #2nd input to discriminator, text embedding\n",
296
+ " concat = concatenate([x, input_layer2])\n",
297
+ "\n",
298
+ " x1 = Conv2D(512, kernel_size=(1,1), padding='same', strides=1, use_bias=False,\n",
299
+ " kernel_initializer='he_uniform')(concat)\n",
300
+ " x1 = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)\n",
301
+ " x1 = LeakyReLU(alpha=0.2)(x)\n",
302
+ "\n",
303
+ " # Flatten and add a FC layer to predict.\n",
304
+ " x1 = Flatten()(x1)\n",
305
+ " x1 = Dense(1)(x1)\n",
306
+ " x1 = Activation('sigmoid')(x1)\n",
307
+ "\n",
308
+ " stage1_dis = Model(inputs=[input_layer1, input_layer2], outputs=[x1]) \n",
309
+ " return stage1_dis"
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "execution_count": 10,
315
+ "id": "98090438",
316
+ "metadata": {},
317
+ "outputs": [
318
+ {
319
+ "name": "stdout",
320
+ "output_type": "stream",
321
+ "text": [
322
+ "Model: \"model_1\"\n",
323
+ "__________________________________________________________________________________________________\n",
324
+ " Layer (type) Output Shape Param # Connected to \n",
325
+ "==================================================================================================\n",
326
+ " input_5 (InputLayer) [(None, 64, 64, 3)] 0 [] \n",
327
+ " \n",
328
+ " conv2d_9 (Conv2D) (None, 32, 32, 64) 3072 ['input_5[0][0]'] \n",
329
+ " \n",
330
+ " leaky_re_lu_5 (LeakyReLU) (None, 32, 32, 64) 0 ['conv2d_9[0][0]'] \n",
331
+ " \n",
332
+ " conv2d_10 (Conv2D) (None, 16, 16, 128) 131072 ['leaky_re_lu_5[0][0]'] \n",
333
+ " \n",
334
+ " batch_normalization_7 (BatchNo (None, 16, 16, 128) 512 ['conv2d_10[0][0]'] \n",
335
+ " rmalization) \n",
336
+ " \n",
337
+ " leaky_re_lu_6 (LeakyReLU) (None, 16, 16, 128) 0 ['batch_normalization_7[0][0]'] \n",
338
+ " \n",
339
+ " conv2d_11 (Conv2D) (None, 8, 8, 256) 524288 ['leaky_re_lu_6[0][0]'] \n",
340
+ " \n",
341
+ " batch_normalization_8 (BatchNo (None, 8, 8, 256) 1024 ['conv2d_11[0][0]'] \n",
342
+ " rmalization) \n",
343
+ " \n",
344
+ " leaky_re_lu_7 (LeakyReLU) (None, 8, 8, 256) 0 ['batch_normalization_8[0][0]'] \n",
345
+ " \n",
346
+ " conv2d_12 (Conv2D) (None, 4, 4, 512) 2097152 ['leaky_re_lu_7[0][0]'] \n",
347
+ " \n",
348
+ " batch_normalization_9 (BatchNo (None, 4, 4, 512) 2048 ['conv2d_12[0][0]'] \n",
349
+ " rmalization) \n",
350
+ " \n",
351
+ " leaky_re_lu_8 (LeakyReLU) (None, 4, 4, 512) 0 ['batch_normalization_9[0][0]'] \n",
352
+ " \n",
353
+ " leaky_re_lu_9 (LeakyReLU) (None, 4, 4, 512) 0 ['leaky_re_lu_8[0][0]'] \n",
354
+ " \n",
355
+ " flatten (Flatten) (None, 8192) 0 ['leaky_re_lu_9[0][0]'] \n",
356
+ " \n",
357
+ " dense_2 (Dense) (None, 1) 8193 ['flatten[0][0]'] \n",
358
+ " \n",
359
+ " input_6 (InputLayer) [(None, 4, 4, 128)] 0 [] \n",
360
+ " \n",
361
+ " activation_1 (Activation) (None, 1) 0 ['dense_2[0][0]'] \n",
362
+ " \n",
363
+ "==================================================================================================\n",
364
+ "Total params: 2,767,361\n",
365
+ "Trainable params: 2,765,569\n",
366
+ "Non-trainable params: 1,792\n",
367
+ "__________________________________________________________________________________________________\n"
368
+ ]
369
+ }
370
+ ],
371
+ "source": [
372
+ "discriminator = build_stage1_discriminator()\n",
373
+ "discriminator.summary()"
374
+ ]
375
+ },
376
+ {
377
+ "cell_type": "markdown",
378
+ "id": "cdc2a75a",
379
+ "metadata": {},
380
+ "source": [
381
+ "### Stage 1 Adversarial Model (Building a GAN)"
382
+ ]
383
+ },
384
+ {
385
+ "cell_type": "code",
386
+ "execution_count": 11,
387
+ "id": "5d0678f7",
388
+ "metadata": {},
389
+ "outputs": [],
390
+ "source": [
391
+ "# Building GAN with Generator and Discriminator\n",
392
+ "\n",
393
+ "def build_adversarial(generator_model, discriminator_model):\n",
394
+ " \"\"\"Stage 1 Adversarial model.\n",
395
+ "\n",
396
+ " Args:\n",
397
+ " generator_model: Stage 1 Generator Model\n",
398
+ " discriminator_model: Stage 1 Discriminator Model\n",
399
+ "\n",
400
+ " Returns:\n",
401
+ " Adversarial Model.\n",
402
+ " \"\"\"\n",
403
+ " input_layer1 = Input(shape=(1024,)) \n",
404
+ " input_layer2 = Input(shape=(100,)) \n",
405
+ " input_layer3 = Input(shape=(4, 4, 128)) \n",
406
+ "\n",
407
+ " x, ca = generator_model([input_layer1, input_layer2]) #text,noise\n",
408
+ "\n",
409
+ " discriminator_model.trainable = False \n",
410
+ "\n",
411
+ " probabilities = discriminator_model([x, input_layer3]) \n",
412
+ " adversarial_model = Model(inputs=[input_layer1, input_layer2, input_layer3], outputs=[probabilities, ca])\n",
413
+ " return adversarial_model"
414
+ ]
415
+ },
416
+ {
417
+ "cell_type": "code",
418
+ "execution_count": 12,
419
+ "id": "bd351c9d",
420
+ "metadata": {},
421
+ "outputs": [
422
+ {
423
+ "name": "stdout",
424
+ "output_type": "stream",
425
+ "text": [
426
+ "Model: \"model_2\"\n",
427
+ "__________________________________________________________________________________________________\n",
428
+ " Layer (type) Output Shape Param # Connected to \n",
429
+ "==================================================================================================\n",
430
+ " input_7 (InputLayer) [(None, 1024)] 0 [] \n",
431
+ " \n",
432
+ " input_8 (InputLayer) [(None, 100)] 0 [] \n",
433
+ " \n",
434
+ " model (Functional) [(None, 64, 64, 3), 10270400 ['input_7[0][0]', \n",
435
+ " (None, 256)] 'input_8[0][0]'] \n",
436
+ " \n",
437
+ " input_9 (InputLayer) [(None, 4, 4, 128)] 0 [] \n",
438
+ " \n",
439
+ " model_1 (Functional) (None, 1) 2767361 ['model[0][0]', \n",
440
+ " 'input_9[0][0]'] \n",
441
+ " \n",
442
+ "==================================================================================================\n",
443
+ "Total params: 13,037,761\n",
444
+ "Trainable params: 10,268,480\n",
445
+ "Non-trainable params: 2,769,281\n",
446
+ "__________________________________________________________________________________________________\n"
447
+ ]
448
+ }
449
+ ],
450
+ "source": [
451
+ "ganstage1 = build_adversarial(generator, discriminator)\n",
452
+ "ganstage1.summary()"
453
+ ]
454
+ },
455
+ {
456
+ "cell_type": "markdown",
457
+ "id": "adf70416",
458
+ "metadata": {},
459
+ "source": [
460
+ "### Train Utilities"
461
+ ]
462
+ },
463
+ {
464
+ "cell_type": "code",
465
+ "execution_count": 13,
466
+ "id": "730c9e8a",
467
+ "metadata": {},
468
+ "outputs": [],
469
+ "source": [
470
+ "def checkpoint_prefix():\n",
471
+ " checkpoint_dir = './training_checkpoints'\n",
472
+ " checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')\n",
473
+ "\n",
474
+ " return checkpoint_prefix\n",
475
+ "\n",
476
+ "def adversarial_loss(y_true, y_pred):\n",
477
+ " mean = y_pred[:, :128]\n",
478
+ " ls = y_pred[:, 128:]\n",
479
+ " loss = -ls + 0.5 * (-1 + tf.math.exp(2.0 * ls) + tf.math.square(mean))\n",
480
+ " loss = K.mean(loss)\n",
481
+ " return loss\n",
482
+ "\n",
483
+ "def normalize(input_image, real_image):\n",
484
+ " input_image = (input_image / 127.5) - 1\n",
485
+ " real_image = (real_image / 127.5) - 1\n",
486
+ "\n",
487
+ " return input_image, real_image\n",
488
+ "\n",
489
+ "def load_class_ids_filenames(class_id_path, filename_path):\n",
490
+ " with open(class_id_path, 'rb') as file:\n",
491
+ " class_id = pickle.load(file, encoding='latin1')\n",
492
+ "\n",
493
+ " with open(filename_path, 'rb') as file:\n",
494
+ " filename = pickle.load(file, encoding='latin1')\n",
495
+ "\n",
496
+ " return class_id, filename\n",
497
+ "\n",
498
+ "def load_text_embeddings(text_embeddings):\n",
499
+ " with open(text_embeddings, 'rb') as file:\n",
500
+ " embeds = pickle.load(file, encoding='latin1')\n",
501
+ " embeds = np.array(embeds)\n",
502
+ "\n",
503
+ " return embeds\n",
504
+ "\n",
505
+ "def load_bbox(data_path):\n",
506
+ " bbox_path = data_path + '/bounding_boxes.txt'\n",
507
+ " image_path = data_path + '/images.txt'\n",
508
+ " bbox_df = pd.read_csv(bbox_path, delim_whitespace=True, header=None).astype(int)\n",
509
+ " filename_df = pd.read_csv(image_path, delim_whitespace=True, header=None)\n",
510
+ "\n",
511
+ " filenames = filename_df[1].tolist()\n",
512
+ " bbox_dict = {i[:-4]:[] for i in filenames[:2]}\n",
513
+ "\n",
514
+ " for i in range(0, len(filenames)):\n",
515
+ " bbox = bbox_df.iloc[i][1:].tolist()\n",
516
+ " dict_key = filenames[i][:-4]\n",
517
+ " bbox_dict[dict_key] = bbox\n",
518
+ "\n",
519
+ " return bbox_dict\n",
520
+ "\n",
521
+ "def load_images(image_path, bounding_box, size):\n",
522
+ " \"\"\"Crops the image to the bounding box and then resizes it.\n",
523
+ " \"\"\"\n",
524
+ " image = Image.open(image_path).convert('RGB')\n",
525
+ " w, h = image.size\n",
526
+ " if bounding_box is not None:\n",
527
+ " r = int(np.maximum(bounding_box[2], bounding_box[3]) * 0.75)\n",
528
+ " c_x = int((bounding_box[0] + bounding_box[2]) / 2)\n",
529
+ " c_y = int((bounding_box[1] + bounding_box[3]) / 2)\n",
530
+ " y1 = np.maximum(0, c_y - r)\n",
531
+ " y2 = np.minimum(h, c_y + r)\n",
532
+ " x1 = np.maximum(0, c_x - r)\n",
533
+ " x2 = np.minimum(w, c_x + r)\n",
534
+ " image = image.crop([x1, y1, x2, y2])\n",
535
+ "\n",
536
+ " image = image.resize(size, PIL.Image.BILINEAR)\n",
537
+ " return image\n",
538
+ "\n",
539
+ "def load_data(filename_path, class_id_path, dataset_path, embeddings_path, size):\n",
540
+ " \"\"\"Loads the Dataset.\n",
541
+ " \"\"\"\n",
542
+ " data_dir = \"D:/1-pipelined_topics/GAN_texttoimage/birds_implementation/birds\"\n",
543
+ " train_dir = data_dir + \"/train\"\n",
544
+ " test_dir = data_dir + \"/test\"\n",
545
+ " embeddings_path_train = train_dir + \"/char-CNN-RNN-embeddings.pickle\"\n",
546
+ " embeddings_path_test = test_dir + \"/char-CNN-RNN-embeddings.pickle\"\n",
547
+ " filename_path_train = train_dir + \"/filenames.pickle\"\n",
548
+ " filename_path_test = test_dir + \"/filenames.pickle\"\n",
549
+ " class_id_path_train = train_dir + \"/class_info.pickle\"\n",
550
+ " class_id_path_test = test_dir + \"/class_info.pickle\"\n",
551
+ " dataset_path = \"D:/1-pipelined_topics/GAN_texttoimage/birds_implementation/CUB_200_2011\"\n",
552
+ " class_id, filenames = load_class_ids_filenames(class_id_path, filename_path)\n",
553
+ " embeddings = load_text_embeddings(embeddings_path)\n",
554
+ " bbox_dict = load_bbox(dataset_path)\n",
555
+ "\n",
556
+ " x, y, embeds = [], [], []\n",
557
+ "\n",
558
+ " for i, filename in enumerate(filenames):\n",
559
+ " bbox = bbox_dict[filename]\n",
560
+ "\n",
561
+ " try:\n",
562
+ " image_path = f'{dataset_path}/images/{filename}.jpg'\n",
563
+ " image = load_images(image_path, bbox, size)\n",
564
+ " e = embeddings[i, :, :]\n",
565
+ " embed_index = np.random.randint(0, e.shape[0] - 1)\n",
566
+ " embed = e[embed_index, :]\n",
567
+ "\n",
568
+ " x.append(np.array(image))\n",
569
+ " y.append(class_id[i])\n",
570
+ " embeds.append(embed)\n",
571
+ "\n",
572
+ " except Exception as e:\n",
573
+ " print(f'{e}')\n",
574
+ " \n",
575
+ " x = np.array(x)\n",
576
+ " y = np.array(y)\n",
577
+ " embeds = np.array(embeds)\n",
578
+ " \n",
579
+ " return x, y, embeds\n",
580
+ "\n",
581
+ "def save_image(file, save_path):\n",
582
+ " \"\"\"Saves the image at the specified file path.\n",
583
+ " \"\"\"\n",
584
+ " image = plt.figure()\n",
585
+ " ax = image.add_subplot(1,1,1)\n",
586
+ " ax.imshow(file)\n",
587
+ " ax.axis(\"off\")\n",
588
+ " plt.savefig(save_path)"
589
+ ]
590
+ },
591
+ {
592
+ "cell_type": "code",
593
+ "execution_count": 28,
594
+ "id": "697f1dc6",
595
+ "metadata": {},
596
+ "outputs": [],
597
+ "source": [
598
+ "############################################################\n",
599
+ "# StackGAN class\n",
600
+ "############################################################\n",
601
+ "\n",
602
+ "class StackGanStage1(object):\n",
603
+ " \"\"\"StackGAN Stage 1 class.\"\"\"\n",
604
+ "\n",
605
+ " data_dir = \"D:/1-pipelined_topics/GAN_texttoimage/birds_implementation/birds\"\n",
606
+ " train_dir = data_dir + \"/train\"\n",
607
+ " test_dir = data_dir + \"/test\"\n",
608
+ " embeddings_path_train = train_dir + \"/char-CNN-RNN-embeddings.pickle\"\n",
609
+ " embeddings_path_test = test_dir + \"/char-CNN-RNN-embeddings.pickle\"\n",
610
+ " filename_path_train = train_dir + \"/filenames.pickle\"\n",
611
+ " filename_path_test = test_dir + \"/filenames.pickle\"\n",
612
+ " class_id_path_train = train_dir + \"/class_info.pickle\"\n",
613
+ " class_id_path_test = test_dir + \"/class_info.pickle\"\n",
614
+ " dataset_path = \"D:/1-pipelined_topics/GAN_texttoimage/birds_implementation/CUB_200_2011\"\n",
615
+ " def __init__(self, epochs=500, z_dim=100, batch_size=64, enable_function=True, stage1_generator_lr=0.0002, stage1_discriminator_lr=0.0002):\n",
616
+ " self.epochs = epochs\n",
617
+ " self.z_dim = z_dim\n",
618
+ " self.enable_function = enable_function\n",
619
+ " self.stage1_generator_lr = stage1_generator_lr\n",
620
+ " self.stage1_discriminator_lr = stage1_discriminator_lr\n",
621
+ " self.image_size = 64\n",
622
+ " self.conditioning_dim = 128\n",
623
+ " self.batch_size = batch_size\n",
624
+ " \n",
625
+ " self.stage1_generator_optimizer = Adam(lr=stage1_generator_lr, beta_1=0.5, beta_2=0.999)\n",
626
+ " self.stage1_discriminator_optimizer = Adam(lr=stage1_discriminator_lr, beta_1=0.5, beta_2=0.999)\n",
627
+ " \n",
628
+ " self.stage1_generator = build_stage1_generator()\n",
629
+ " self.stage1_generator.compile(loss='mse', optimizer=self.stage1_generator_optimizer)\n",
630
+ "\n",
631
+ " self.stage1_discriminator = build_stage1_discriminator()\n",
632
+ " self.stage1_discriminator.compile(loss='binary_crossentropy', optimizer=self.stage1_discriminator_optimizer)\n",
633
+ "\n",
634
+ " self.ca_network = build_ca_network()\n",
635
+ " self.ca_network.compile(loss='binary_crossentropy', optimizer='Adam')\n",
636
+ "\n",
637
+ " self.embedding_compressor = build_embedding_compressor()\n",
638
+ " self.embedding_compressor.compile(loss='binary_crossentropy', optimizer='Adam')\n",
639
+ "\n",
640
+ " self.stage1_adversarial = build_adversarial(self.stage1_generator, self.stage1_discriminator)\n",
641
+ " self.stage1_adversarial.compile(loss=['binary_crossentropy', adversarial_loss], loss_weights=[1, 2.0], optimizer=self.stage1_generator_optimizer)\n",
642
+ "\n",
643
+ " self.checkpoint1 = tf.train.Checkpoint(\n",
644
+ " generator_optimizer=self.stage1_generator_optimizer,\n",
645
+ " discriminator_optimizer=self.stage1_discriminator_optimizer,\n",
646
+ " generator=self.stage1_generator,\n",
647
+ " discriminator=self.stage1_discriminator)\n",
648
+ "\n",
649
+ " def visualize_stage1(self):\n",
650
+ " \"\"\"Running Tensorboard visualizations.\n",
651
+ " \"\"\"\n",
652
+ " tb = TensorBoard(log_dir=\"logs/\".format(time.time()))\n",
653
+ " tb.set_model(self.stage1_generator)\n",
654
+ " tb.set_model(self.stage1_discriminator)\n",
655
+ " tb.set_model(self.ca_network)\n",
656
+ " tb.set_model(self.embedding_compressor)\n",
657
+ "\n",
658
+ " def train_stage1(self):\n",
659
+ " \"\"\"Trains the stage1 StackGAN.\n",
660
+ " \"\"\"\n",
661
+ " x_train, y_train, train_embeds = load_data(filename_path=filename_path_train, class_id_path=class_id_path_train,\n",
662
+ " dataset_path=dataset_path, embeddings_path=embeddings_path_train, size=(64, 64))\n",
663
+ "\n",
664
+ " x_test, y_test, test_embeds = load_data(filename_path=filename_path_test, class_id_path=class_id_path_test, \n",
665
+ " dataset_path=dataset_path, embeddings_path=embeddings_path_test, size=(64, 64))\n",
666
+ "\n",
667
+ " real = np.ones((self.batch_size, 1), dtype='float') * 0.9\n",
668
+ " fake = np.zeros((self.batch_size, 1), dtype='float') * 0.1\n",
669
+ "\n",
670
+ " for epoch in range(self.epochs):\n",
671
+ " print(f'Epoch: {epoch}')\n",
672
+ "\n",
673
+ " gen_loss = []\n",
674
+ " dis_loss = []\n",
675
+ "\n",
676
+ " num_batches = int(x_train.shape[0] / self.batch_size)\n",
677
+ "\n",
678
+ " for i in range(num_batches):\n",
679
+ "\n",
680
+ " latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))\n",
681
+ " embedding_text = train_embeds[i * self.batch_size:(i + 1) * self.batch_size]\n",
682
+ " compressed_embedding = self.embedding_compressor.predict_on_batch(embedding_text)\n",
683
+ " compressed_embedding = np.reshape(compressed_embedding, (-1, 1, 1, 128))\n",
684
+ " compressed_embedding = np.tile(compressed_embedding, (1, 4, 4, 1))\n",
685
+ "\n",
686
+ " image_batch = x_train[i * self.batch_size:(i+1) * self.batch_size]\n",
687
+ " image_batch = (image_batch - 127.5) / 127.5\n",
688
+ "\n",
689
+ " gen_images, _ = self.stage1_generator.predict([embedding_text, latent_space])\n",
690
+ "\n",
691
+ " discriminator_loss = self.stage1_discriminator.train_on_batch([image_batch, compressed_embedding], \n",
692
+ " np.reshape(real, (self.batch_size, 1)))\n",
693
+ "\n",
694
+ " discriminator_loss_gen = self.stage1_discriminator.train_on_batch([gen_images, compressed_embedding],\n",
695
+ " np.reshape(fake, (self.batch_size, 1)))\n",
696
+ "\n",
697
+ " discriminator_loss_wrong = self.stage1_discriminator.train_on_batch([gen_images[: self.batch_size-1], compressed_embedding[1:]], \n",
698
+ " np.reshape(fake[1:], (self.batch_size-1, 1)))\n",
699
+ "\n",
700
+ "# Discriminator loss\n",
701
+ " d_loss = 0.5 * np.add(discriminator_loss, 0.5 * np.add(discriminator_loss_gen, discriminator_loss_wrong))\n",
702
+ " dis_loss.append(d_loss)\n",
703
+ "\n",
704
+ " print(f'Discriminator Loss: {d_loss}')\n",
705
+ "\n",
706
+ " # Generator loss\n",
707
+ " g_loss = self.stage1_adversarial.train_on_batch([embedding_text, latent_space, compressed_embedding],\n",
708
+ " [K.ones((self.batch_size, 1)) * 0.9, K.ones((self.batch_size, 256)) * 0.9])\n",
709
+ "\n",
710
+ " print(f'Generator Loss: {g_loss}')\n",
711
+ " gen_loss.append(g_loss)\n",
712
+ "\n",
713
+ " if epoch % 5 == 0:\n",
714
+ " latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))\n",
715
+ " embedding_batch = test_embeds[0 : self.batch_size]\n",
716
+ " gen_images, _ = self.stage1_generator.predict_on_batch([embedding_batch, latent_space])\n",
717
+ "\n",
718
+ " for i, image in enumerate(gen_images[:10]):\n",
719
+ " save_image(image, f'test/gen_1_{epoch}_{i}')\n",
720
+ "\n",
721
+ " if epoch % 25 == 0:\n",
722
+ " self.stage1_generator.save_weights('weights/stage1_gen.h5')\n",
723
+ " self.stage1_discriminator.save_weights(\"weights/stage1_disc.h5\")\n",
724
+ " self.ca_network.save_weights('weights/stage1_ca.h5')\n",
725
+ " self.embedding_compressor.save_weights('weights/stage1_embco.h5')\n",
726
+ " self.stage1_adversarial.save_weights('weights/stage1_adv.h5') \n",
727
+ "\n",
728
+ " self.stage1_generator.save_weights('weights/stage1_gen.h5')\n",
729
+ " self.stage1_discriminator.save_weights(\"weights/stage1_disc.h5\")"
730
+ ]
731
+ },
732
+ {
733
+ "cell_type": "code",
734
+ "execution_count": null,
735
+ "id": "517037ac",
736
+ "metadata": {},
737
+ "outputs": [],
738
+ "source": [
739
+ "stage1 = StackGanStage1()\n",
740
+ "stage1.train_stage1()"
741
+ ]
742
+ },
743
+ {
744
+ "cell_type": "markdown",
745
+ "id": "7d85b9f2",
746
+ "metadata": {},
747
+ "source": [
748
+ "### Check the test folder for generated images from the Stage 1 Generator\n",
749
+ "### Let's implement the Stage 2 Generator"
750
+ ]
751
+ },
752
+ {
753
+ "cell_type": "code",
754
+ "execution_count": 29,
755
+ "id": "2e45c731",
756
+ "metadata": {},
757
+ "outputs": [],
758
+ "source": [
759
+ "############################################################\n",
760
+ "# Stage 2 Generator Network\n",
761
+ "############################################################\n",
762
+ "\n",
763
+ "def concat_along_dims(inputs):\n",
764
+ " \"\"\"Tiles the conditioning vector to 16x16 and concatenates it with the encoded image along the channel axis (axis 3).\n",
765
+ "\n",
766
+ " Args:\n",
767
+ " inputs: list [c, x] holding the conditioning vector c and the encoded image tensor x.\n",
768
+ "\n",
769
+ " Returns:\n",
770
+ " Joint block along the dimensions.\n",
771
+ " \"\"\"\n",
772
+ " c = inputs[0]\n",
773
+ " x = inputs[1]\n",
774
+ "\n",
775
+ " c = K.expand_dims(c, axis=1)\n",
776
+ " c = K.expand_dims(c, axis=1)\n",
777
+ " c = K.tile(c, [1, 16, 16, 1])\n",
778
+ " return K.concatenate([c, x], axis = 3)\n",
779
+ "\n",
780
+ "def residual_block(input):\n",
781
+ " \"\"\"Residual block with plain identity connections.\n",
782
+ "\n",
783
+ " Args:\n",
784
+ " input: input layer or an encoded layer\n",
785
+ "\n",
786
+ " Returns:\n",
787
+ " Layer with computed identity mapping.\n",
788
+ " \"\"\"\n",
789
+ " x = Conv2D(512, kernel_size=(3,3), padding='same', use_bias=False,\n",
790
+ " kernel_initializer='he_uniform')(input)\n",
791
+ " x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)\n",
792
+ " x = ReLU()(x)\n",
793
+ " \n",
794
+ " x = Conv2D(512, kernel_size=(3,3), padding='same', use_bias=False,\n",
795
+ " kernel_initializer='he_uniform')(x)\n",
796
+ " x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)\n",
797
+ " \n",
798
+ " x = add([x, input])\n",
799
+ " x = ReLU()(x)\n",
800
+ "\n",
801
+ " return x\n",
802
+ "\n",
803
+ "def build_stage2_generator():\n",
804
+ " \"\"\"Build the Stage 2 Generator Network using the conditioning text and images from stage 1.\n",
805
+ "\n",
806
+ " Returns:\n",
807
+ " Stage 2 Generator Model for StackGAN.\n",
808
+ " \"\"\"\n",
809
+ " input_layer1 = Input(shape=(1024,))\n",
810
+ " input_images = Input(shape=(64, 64, 3))\n",
811
+ "\n",
812
+ " # Conditioning Augmentation\n",
813
+ " ca = Dense(256)(input_layer1)\n",
814
+ " mls = LeakyReLU(alpha=0.2)(ca)\n",
815
+ " c = Lambda(conditioning_augmentation)(mls)\n",
816
+ "\n",
817
+ " # Downsampling block\n",
818
+ " x = ZeroPadding2D(padding=(1,1))(input_images)\n",
819
+ " x = Conv2D(128, kernel_size=(3,3), strides=1, use_bias=False,\n",
820
+ " kernel_initializer='he_uniform')(x)\n",
821
+ " x = ReLU()(x)\n",
822
+ "\n",
823
+ " x = ZeroPadding2D(padding=(1,1))(x)\n",
824
+ " x = Conv2D(256, kernel_size=(4,4), strides=2, use_bias=False,\n",
825
+ " kernel_initializer='he_uniform')(x)\n",
826
+ " x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)\n",
827
+ " x = ReLU()(x)\n",
828
+ "\n",
829
+ " x = ZeroPadding2D(padding=(1,1))(x)\n",
830
+ " x = Conv2D(512, kernel_size=(4,4), strides=2, use_bias=False,\n",
831
+ " kernel_initializer='he_uniform')(x)\n",
832
+ " x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)\n",
833
+ " x = ReLU()(x)\n",
834
+ "\n",
835
+ " # Concatenate text conditioning block with the encoded image\n",
836
+ " concat = concat_along_dims([c, x])\n",
837
+ "\n",
838
+ " # Residual Blocks\n",
839
+ " x = ZeroPadding2D(padding=(1,1))(concat)\n",
840
+ " x = Conv2D(512, kernel_size=(3,3), use_bias=False, kernel_initializer='he_uniform')(x)\n",
841
+ " x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)\n",
842
+ " x = ReLU()(x)\n",
843
+ "\n",
844
+ " x = residual_block(x)\n",
845
+ " x = residual_block(x)\n",
846
+ " x = residual_block(x)\n",
847
+ " x = residual_block(x)\n",
848
+ " \n",
849
+ " # Upsampling Blocks\n",
850
+ " x = UpSamplingBlock(x, 512)\n",
851
+ " x = UpSamplingBlock(x, 256)\n",
852
+ " x = UpSamplingBlock(x, 128)\n",
853
+ " x = UpSamplingBlock(x, 64)\n",
854
+ "\n",
855
+ " x = Conv2D(3, kernel_size=(3,3), padding='same', use_bias=False, kernel_initializer='he_uniform')(x)\n",
856
+ " x = Activation('tanh')(x)\n",
857
+ " \n",
858
+ " stage2_gen = Model(inputs=[input_layer1, input_images], outputs=[x, mls])\n",
859
+ " return stage2_gen"
860
+ ]
861
+ },
862
+ {
863
+ "cell_type": "code",
864
+ "execution_count": 30,
865
+ "id": "76c876db",
866
+ "metadata": {},
867
+ "outputs": [
868
+ {
869
+ "name": "stdout",
870
+ "output_type": "stream",
871
+ "text": [
872
+ "Model: \"model_3\"\n",
873
+ "__________________________________________________________________________________________________\n",
874
+ " Layer (type) Output Shape Param # Connected to \n",
875
+ "==================================================================================================\n",
876
+ " input_11 (InputLayer) [(None, 64, 64, 3)] 0 [] \n",
877
+ " \n",
878
+ " zero_padding2d (ZeroPadding2D) (None, 66, 66, 3) 0 ['input_11[0][0]'] \n",
879
+ " \n",
880
+ " conv2d_14 (Conv2D) (None, 64, 64, 128) 3456 ['zero_padding2d[0][0]'] \n",
881
+ " \n",
882
+ " re_lu_5 (ReLU) (None, 64, 64, 128) 0 ['conv2d_14[0][0]'] \n",
883
+ " \n",
884
+ " zero_padding2d_1 (ZeroPadding2 (None, 66, 66, 128) 0 ['re_lu_5[0][0]'] \n",
885
+ " D) \n",
886
+ " \n",
887
+ " input_10 (InputLayer) [(None, 1024)] 0 [] \n",
888
+ " \n",
889
+ " conv2d_15 (Conv2D) (None, 32, 32, 256) 524288 ['zero_padding2d_1[0][0]'] \n",
890
+ " \n",
891
+ " dense_3 (Dense) (None, 256) 262400 ['input_10[0][0]'] \n",
892
+ " \n",
893
+ " batch_normalization_11 (BatchN (None, 32, 32, 256) 1024 ['conv2d_15[0][0]'] \n",
894
+ " ormalization) \n",
895
+ " \n",
896
+ " leaky_re_lu_10 (LeakyReLU) (None, 256) 0 ['dense_3[0][0]'] \n",
897
+ " \n",
898
+ " re_lu_6 (ReLU) (None, 32, 32, 256) 0 ['batch_normalization_11[0][0]'] \n",
899
+ " \n",
900
+ " lambda_1 (Lambda) (None, 128) 0 ['leaky_re_lu_10[0][0]'] \n",
901
+ " \n",
902
+ " zero_padding2d_2 (ZeroPadding2 (None, 34, 34, 256) 0 ['re_lu_6[0][0]'] \n",
903
+ " D) \n",
904
+ " \n",
905
+ " tf.expand_dims (TFOpLambda) (None, 1, 128) 0 ['lambda_1[0][0]'] \n",
906
+ " \n",
907
+ " conv2d_16 (Conv2D) (None, 16, 16, 512) 2097152 ['zero_padding2d_2[0][0]'] \n",
908
+ " \n",
909
+ " tf.expand_dims_1 (TFOpLambda) (None, 1, 1, 128) 0 ['tf.expand_dims[0][0]'] \n",
910
+ " \n",
911
+ " batch_normalization_12 (BatchN (None, 16, 16, 512) 2048 ['conv2d_16[0][0]'] \n",
912
+ " ormalization) \n",
913
+ " \n",
914
+ " tf.tile (TFOpLambda) (None, 16, 16, 128) 0 ['tf.expand_dims_1[0][0]'] \n",
915
+ " \n",
916
+ " re_lu_7 (ReLU) (None, 16, 16, 512) 0 ['batch_normalization_12[0][0]'] \n",
917
+ " \n",
918
+ " tf.concat (TFOpLambda) (None, 16, 16, 640) 0 ['tf.tile[0][0]', \n",
919
+ " 're_lu_7[0][0]'] \n",
920
+ " \n",
921
+ " zero_padding2d_3 (ZeroPadding2 (None, 18, 18, 640) 0 ['tf.concat[0][0]'] \n",
922
+ " D) \n",
923
+ " \n",
924
+ " conv2d_17 (Conv2D) (None, 16, 16, 512) 2949120 ['zero_padding2d_3[0][0]'] \n",
925
+ " \n",
926
+ " batch_normalization_13 (BatchN (None, 16, 16, 512) 2048 ['conv2d_17[0][0]'] \n",
927
+ " ormalization) \n",
928
+ " \n",
929
+ " re_lu_8 (ReLU) (None, 16, 16, 512) 0 ['batch_normalization_13[0][0]'] \n",
930
+ " \n",
931
+ " conv2d_18 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_8[0][0]'] \n",
932
+ " \n",
933
+ " batch_normalization_14 (BatchN (None, 16, 16, 512) 2048 ['conv2d_18[0][0]'] \n",
934
+ " ormalization) \n",
935
+ " \n",
936
+ " re_lu_9 (ReLU) (None, 16, 16, 512) 0 ['batch_normalization_14[0][0]'] \n",
937
+ " \n",
938
+ " conv2d_19 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_9[0][0]'] \n",
939
+ " \n",
940
+ " batch_normalization_15 (BatchN (None, 16, 16, 512) 2048 ['conv2d_19[0][0]'] \n",
941
+ " ormalization) \n",
942
+ " \n",
943
+ " add (Add) (None, 16, 16, 512) 0 ['batch_normalization_15[0][0]', \n",
944
+ " 're_lu_8[0][0]'] \n",
945
+ " \n",
946
+ " re_lu_10 (ReLU) (None, 16, 16, 512) 0 ['add[0][0]'] \n",
947
+ " \n",
948
+ " conv2d_20 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_10[0][0]'] \n",
949
+ " \n",
950
+ " batch_normalization_16 (BatchN (None, 16, 16, 512) 2048 ['conv2d_20[0][0]'] \n",
951
+ " ormalization) \n",
952
+ " \n",
953
+ " re_lu_11 (ReLU) (None, 16, 16, 512) 0 ['batch_normalization_16[0][0]'] \n",
954
+ " \n",
955
+ " conv2d_21 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_11[0][0]'] \n"
956
+ ]
957
+ },
958
+ {
959
+ "name": "stdout",
960
+ "output_type": "stream",
961
+ "text": [
962
+ " \n",
963
+ " batch_normalization_17 (BatchN (None, 16, 16, 512) 2048 ['conv2d_21[0][0]'] \n",
964
+ " ormalization) \n",
965
+ " \n",
966
+ " add_1 (Add) (None, 16, 16, 512) 0 ['batch_normalization_17[0][0]', \n",
967
+ " 're_lu_10[0][0]'] \n",
968
+ " \n",
969
+ " re_lu_12 (ReLU) (None, 16, 16, 512) 0 ['add_1[0][0]'] \n",
970
+ " \n",
971
+ " conv2d_22 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_12[0][0]'] \n",
972
+ " \n",
973
+ " batch_normalization_18 (BatchN (None, 16, 16, 512) 2048 ['conv2d_22[0][0]'] \n",
974
+ " ormalization) \n",
975
+ " \n",
976
+ " re_lu_13 (ReLU) (None, 16, 16, 512) 0 ['batch_normalization_18[0][0]'] \n",
977
+ " \n",
978
+ " conv2d_23 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_13[0][0]'] \n",
979
+ " \n",
980
+ " batch_normalization_19 (BatchN (None, 16, 16, 512) 2048 ['conv2d_23[0][0]'] \n",
981
+ " ormalization) \n",
982
+ " \n",
983
+ " add_2 (Add) (None, 16, 16, 512) 0 ['batch_normalization_19[0][0]', \n",
984
+ " 're_lu_12[0][0]'] \n",
985
+ " \n",
986
+ " re_lu_14 (ReLU) (None, 16, 16, 512) 0 ['add_2[0][0]'] \n",
987
+ " \n",
988
+ " conv2d_24 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_14[0][0]'] \n",
989
+ " \n",
990
+ " batch_normalization_20 (BatchN (None, 16, 16, 512) 2048 ['conv2d_24[0][0]'] \n",
991
+ " ormalization) \n",
992
+ " \n",
993
+ " re_lu_15 (ReLU) (None, 16, 16, 512) 0 ['batch_normalization_20[0][0]'] \n",
994
+ " \n",
995
+ " conv2d_25 (Conv2D) (None, 16, 16, 512) 2359296 ['re_lu_15[0][0]'] \n",
996
+ " \n",
997
+ " batch_normalization_21 (BatchN (None, 16, 16, 512) 2048 ['conv2d_25[0][0]'] \n",
998
+ " ormalization) \n",
999
+ " \n",
1000
+ " add_3 (Add) (None, 16, 16, 512) 0 ['batch_normalization_21[0][0]', \n",
1001
+ " 're_lu_14[0][0]'] \n",
1002
+ " \n",
1003
+ " re_lu_16 (ReLU) (None, 16, 16, 512) 0 ['add_3[0][0]'] \n",
1004
+ " \n",
1005
+ " up_sampling2d_4 (UpSampling2D) (None, 32, 32, 512) 0 ['re_lu_16[0][0]'] \n",
1006
+ " \n",
1007
+ " conv2d_26 (Conv2D) (None, 32, 32, 512) 2359296 ['up_sampling2d_4[0][0]'] \n",
1008
+ " \n",
1009
+ " batch_normalization_22 (BatchN (None, 32, 32, 512) 2048 ['conv2d_26[0][0]'] \n",
1010
+ " ormalization) \n",
1011
+ " \n",
1012
+ " re_lu_17 (ReLU) (None, 32, 32, 512) 0 ['batch_normalization_22[0][0]'] \n",
1013
+ " \n",
1014
+ " up_sampling2d_5 (UpSampling2D) (None, 64, 64, 512) 0 ['re_lu_17[0][0]'] \n",
1015
+ " \n",
1016
+ " conv2d_27 (Conv2D) (None, 64, 64, 256) 1179648 ['up_sampling2d_5[0][0]'] \n",
1017
+ " \n",
1018
+ " batch_normalization_23 (BatchN (None, 64, 64, 256) 1024 ['conv2d_27[0][0]'] \n",
1019
+ " ormalization) \n",
1020
+ " \n",
1021
+ " re_lu_18 (ReLU) (None, 64, 64, 256) 0 ['batch_normalization_23[0][0]'] \n",
1022
+ " \n",
1023
+ " up_sampling2d_6 (UpSampling2D) (None, 128, 128, 25 0 ['re_lu_18[0][0]'] \n",
1024
+ " 6) \n",
1025
+ " \n",
1026
+ " conv2d_28 (Conv2D) (None, 128, 128, 12 294912 ['up_sampling2d_6[0][0]'] \n",
1027
+ " 8) \n",
1028
+ " \n",
1029
+ " batch_normalization_24 (BatchN (None, 128, 128, 12 512 ['conv2d_28[0][0]'] \n",
1030
+ " ormalization) 8) \n",
1031
+ " \n",
1032
+ " re_lu_19 (ReLU) (None, 128, 128, 12 0 ['batch_normalization_24[0][0]'] \n",
1033
+ " 8) \n",
1034
+ " \n",
1035
+ " up_sampling2d_7 (UpSampling2D) (None, 256, 256, 12 0 ['re_lu_19[0][0]'] \n",
1036
+ " 8) \n",
1037
+ " \n",
1038
+ " conv2d_29 (Conv2D) (None, 256, 256, 64 73728 ['up_sampling2d_7[0][0]'] \n",
1039
+ " ) \n",
1040
+ " \n",
1041
+ " batch_normalization_25 (BatchN (None, 256, 256, 64 256 ['conv2d_29[0][0]'] \n",
1042
+ " ormalization) ) \n",
1043
+ " \n",
1044
+ " re_lu_20 (ReLU) (None, 256, 256, 64 0 ['batch_normalization_25[0][0]'] \n"
1045
+ ]
1046
+ },
1047
+ {
1048
+ "name": "stdout",
1049
+ "output_type": "stream",
1050
+ "text": [
1051
+ " ) \n",
1052
+ " \n",
1053
+ " conv2d_30 (Conv2D) (None, 256, 256, 3) 1728 ['re_lu_20[0][0]'] \n",
1054
+ " \n",
1055
+ " activation_2 (Activation) (None, 256, 256, 3) 0 ['conv2d_30[0][0]'] \n",
1056
+ " \n",
1057
+ "==================================================================================================\n",
1058
+ "Total params: 28,645,440\n",
1059
+ "Trainable params: 28,632,768\n",
1060
+ "Non-trainable params: 12,672\n",
1061
+ "__________________________________________________________________________________________________\n"
1062
+ ]
1063
+ }
1064
+ ],
1065
+ "source": [
1066
+ "generator_stage2 = build_stage2_generator()\n",
1067
+ "generator_stage2.summary()"
1068
+ ]
1069
+ },
1070
+ {
1071
+ "cell_type": "code",
1072
+ "execution_count": 31,
1073
+ "id": "41de758a",
1074
+ "metadata": {},
1075
+ "outputs": [],
1076
+ "source": [
1077
+ "############################################################\n",
1078
+ "# Stage 2 Discriminator Network\n",
1079
+ "############################################################\n",
1080
+ "\n",
1081
+ "def build_stage2_discriminator():\n",
1082
+ " \"\"\"Builds the Stage 2 Discriminator that uses the 256x256 resolution images from the generator\n",
1083
+ " and the compressed and spatially replicated embeddings.\n",
1084
+ "\n",
1085
+ " Returns:\n",
1086
+ " Stage 2 Discriminator Model for StackGAN.\n",
1087
+ " \"\"\"\n",
1088
+ " input_layer1 = Input(shape=(256, 256, 3))\n",
1089
+ "\n",
1090
+ " x = Conv2D(64, kernel_size=(4,4), padding='same', strides=2, use_bias=False,\n",
1091
+ " kernel_initializer='he_uniform')(input_layer1)\n",
1092
+ " x = LeakyReLU(alpha=0.2)(x)\n",
1093
+ "\n",
1094
+ " x = ConvBlock(x, 128)\n",
1095
+ " x = ConvBlock(x, 256)\n",
1096
+ " x = ConvBlock(x, 512)\n",
1097
+ " x = ConvBlock(x, 1024)\n",
1098
+ " x = ConvBlock(x, 2048)\n",
1099
+ " x = ConvBlock(x, 1024, (1,1), 1)\n",
1100
+ " x = ConvBlock(x, 512, (1,1), 1, False)\n",
1101
+ "\n",
1102
+ " x1 = ConvBlock(x, 128, (1,1), 1)\n",
1103
+ " x1 = ConvBlock(x1, 128, (3,3), 1)\n",
1104
+ " x1 = ConvBlock(x1, 512, (3,3), 1, False)\n",
1105
+ "\n",
1106
+ " x2 = add([x, x1])\n",
1107
+ " x2 = LeakyReLU(alpha=0.2)(x2)\n",
1108
+ "\n",
1109
+ " # Concatenate compressed and spatially replicated embedding\n",
1110
+ " input_layer2 = Input(shape=(4, 4, 128))\n",
1111
+ " concat = concatenate([x2, input_layer2])\n",
1112
+ "\n",
1113
+ " x3 = Conv2D(512, kernel_size=(1,1), strides=1, padding='same', kernel_initializer='he_uniform')(concat)\n",
1114
+ " x3 = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x3)\n",
1115
+ " x3 = LeakyReLU(alpha=0.2)(x3)\n",
1116
+ "\n",
1117
+ " # Flatten and add a FC layer\n",
1118
+ " x3 = Flatten()(x3)\n",
1119
+ " x3 = Dense(1)(x3)\n",
1120
+ " x3 = Activation('sigmoid')(x3)\n",
1121
+ "\n",
1122
+ " stage2_dis = Model(inputs=[input_layer1, input_layer2], outputs=[x3])\n",
1123
+ " return stage2_dis"
1124
+ ]
1125
+ },
1126
+ {
1127
+ "cell_type": "code",
1128
+ "execution_count": 32,
1129
+ "id": "7dbcbc4e",
1130
+ "metadata": {},
1131
+ "outputs": [
1132
+ {
1133
+ "name": "stdout",
1134
+ "output_type": "stream",
1135
+ "text": [
1136
+ "Model: \"model_4\"\n",
1137
+ "__________________________________________________________________________________________________\n",
1138
+ " Layer (type) Output Shape Param # Connected to \n",
1139
+ "==================================================================================================\n",
1140
+ " input_12 (InputLayer) [(None, 256, 256, 3 0 [] \n",
1141
+ " )] \n",
1142
+ " \n",
1143
+ " conv2d_31 (Conv2D) (None, 128, 128, 64 3072 ['input_12[0][0]'] \n",
1144
+ " ) \n",
1145
+ " \n",
1146
+ " leaky_re_lu_11 (LeakyReLU) (None, 128, 128, 64 0 ['conv2d_31[0][0]'] \n",
1147
+ " ) \n",
1148
+ " \n",
1149
+ " conv2d_32 (Conv2D) (None, 64, 64, 128) 131072 ['leaky_re_lu_11[0][0]'] \n",
1150
+ " \n",
1151
+ " batch_normalization_26 (BatchN (None, 64, 64, 128) 512 ['conv2d_32[0][0]'] \n",
1152
+ " ormalization) \n",
1153
+ " \n",
1154
+ " leaky_re_lu_12 (LeakyReLU) (None, 64, 64, 128) 0 ['batch_normalization_26[0][0]'] \n",
1155
+ " \n",
1156
+ " conv2d_33 (Conv2D) (None, 32, 32, 256) 524288 ['leaky_re_lu_12[0][0]'] \n",
1157
+ " \n",
1158
+ " batch_normalization_27 (BatchN (None, 32, 32, 256) 1024 ['conv2d_33[0][0]'] \n",
1159
+ " ormalization) \n",
1160
+ " \n",
1161
+ " leaky_re_lu_13 (LeakyReLU) (None, 32, 32, 256) 0 ['batch_normalization_27[0][0]'] \n",
1162
+ " \n",
1163
+ " conv2d_34 (Conv2D) (None, 16, 16, 512) 2097152 ['leaky_re_lu_13[0][0]'] \n",
1164
+ " \n",
1165
+ " batch_normalization_28 (BatchN (None, 16, 16, 512) 2048 ['conv2d_34[0][0]'] \n",
1166
+ " ormalization) \n",
1167
+ " \n",
1168
+ " leaky_re_lu_14 (LeakyReLU) (None, 16, 16, 512) 0 ['batch_normalization_28[0][0]'] \n",
1169
+ " \n",
1170
+ " conv2d_35 (Conv2D) (None, 8, 8, 1024) 8388608 ['leaky_re_lu_14[0][0]'] \n",
1171
+ " \n",
1172
+ " batch_normalization_29 (BatchN (None, 8, 8, 1024) 4096 ['conv2d_35[0][0]'] \n",
1173
+ " ormalization) \n",
1174
+ " \n",
1175
+ " leaky_re_lu_15 (LeakyReLU) (None, 8, 8, 1024) 0 ['batch_normalization_29[0][0]'] \n",
1176
+ " \n",
1177
+ " conv2d_36 (Conv2D) (None, 4, 4, 2048) 33554432 ['leaky_re_lu_15[0][0]'] \n",
1178
+ " \n",
1179
+ " batch_normalization_30 (BatchN (None, 4, 4, 2048) 8192 ['conv2d_36[0][0]'] \n",
1180
+ " ormalization) \n",
1181
+ " \n",
1182
+ " leaky_re_lu_16 (LeakyReLU) (None, 4, 4, 2048) 0 ['batch_normalization_30[0][0]'] \n",
1183
+ " \n",
1184
+ " conv2d_37 (Conv2D) (None, 4, 4, 1024) 2097152 ['leaky_re_lu_16[0][0]'] \n",
1185
+ " \n",
1186
+ " batch_normalization_31 (BatchN (None, 4, 4, 1024) 4096 ['conv2d_37[0][0]'] \n",
1187
+ " ormalization) \n",
1188
+ " \n",
1189
+ " leaky_re_lu_17 (LeakyReLU) (None, 4, 4, 1024) 0 ['batch_normalization_31[0][0]'] \n",
1190
+ " \n",
1191
+ " conv2d_38 (Conv2D) (None, 4, 4, 512) 524288 ['leaky_re_lu_17[0][0]'] \n",
1192
+ " \n",
1193
+ " batch_normalization_32 (BatchN (None, 4, 4, 512) 2048 ['conv2d_38[0][0]'] \n",
1194
+ " ormalization) \n",
1195
+ " \n",
1196
+ " conv2d_39 (Conv2D) (None, 4, 4, 128) 65536 ['batch_normalization_32[0][0]'] \n",
1197
+ " \n",
1198
+ " batch_normalization_33 (BatchN (None, 4, 4, 128) 512 ['conv2d_39[0][0]'] \n",
1199
+ " ormalization) \n",
1200
+ " \n",
1201
+ " leaky_re_lu_18 (LeakyReLU) (None, 4, 4, 128) 0 ['batch_normalization_33[0][0]'] \n",
1202
+ " \n",
1203
+ " conv2d_40 (Conv2D) (None, 4, 4, 128) 147456 ['leaky_re_lu_18[0][0]'] \n",
1204
+ " \n",
1205
+ " batch_normalization_34 (BatchN (None, 4, 4, 128) 512 ['conv2d_40[0][0]'] \n",
1206
+ " ormalization) \n",
1207
+ " \n",
1208
+ " leaky_re_lu_19 (LeakyReLU) (None, 4, 4, 128) 0 ['batch_normalization_34[0][0]'] \n",
1209
+ " \n",
1210
+ " conv2d_41 (Conv2D) (None, 4, 4, 512) 589824 ['leaky_re_lu_19[0][0]'] \n",
1211
+ " \n",
1212
+ " batch_normalization_35 (BatchN (None, 4, 4, 512) 2048 ['conv2d_41[0][0]'] \n",
1213
+ " ormalization) \n",
1214
+ " \n",
1215
+ " add_4 (Add) (None, 4, 4, 512) 0 ['batch_normalization_32[0][0]', \n",
1216
+ " 'batch_normalization_35[0][0]'] \n",
1217
+ " \n",
1218
+ " leaky_re_lu_20 (LeakyReLU) (None, 4, 4, 512) 0 ['add_4[0][0]'] \n",
1219
+ " \n"
1220
+ ]
1221
+ },
1222
+ {
1223
+ "name": "stdout",
1224
+ "output_type": "stream",
1225
+ "text": [
1226
+ " input_13 (InputLayer) [(None, 4, 4, 128)] 0 [] \n",
1227
+ " \n",
1228
+ " concatenate_2 (Concatenate) (None, 4, 4, 640) 0 ['leaky_re_lu_20[0][0]', \n",
1229
+ " 'input_13[0][0]'] \n",
1230
+ " \n",
1231
+ " conv2d_42 (Conv2D) (None, 4, 4, 512) 328192 ['concatenate_2[0][0]'] \n",
1232
+ " \n",
1233
+ " batch_normalization_36 (BatchN (None, 4, 4, 512) 2048 ['conv2d_42[0][0]'] \n",
1234
+ " ormalization) \n",
1235
+ " \n",
1236
+ " leaky_re_lu_21 (LeakyReLU) (None, 4, 4, 512) 0 ['batch_normalization_36[0][0]'] \n",
1237
+ " \n",
1238
+ " flatten_1 (Flatten) (None, 8192) 0 ['leaky_re_lu_21[0][0]'] \n",
1239
+ " \n",
1240
+ " dense_4 (Dense) (None, 1) 8193 ['flatten_1[0][0]'] \n",
1241
+ " \n",
1242
+ " activation_3 (Activation) (None, 1) 0 ['dense_4[0][0]'] \n",
1243
+ " \n",
1244
+ "==================================================================================================\n",
1245
+ "Total params: 48,486,401\n",
1246
+ "Trainable params: 48,472,833\n",
1247
+ "Non-trainable params: 13,568\n",
1248
+ "__________________________________________________________________________________________________\n"
1249
+ ]
1250
+ }
1251
+ ],
1252
+ "source": [
1253
+ "discriminator_stage2 = build_stage2_discriminator()\n",
1254
+ "discriminator_stage2.summary()"
1255
+ ]
1256
+ },
1257
+ {
1258
+ "cell_type": "code",
1259
+ "execution_count": 33,
1260
+ "id": "7131179e",
1261
+ "metadata": {},
1262
+ "outputs": [],
1263
+ "source": [
1264
+ "############################################################\n",
1265
+ "# Stage 2 Adversarial Model\n",
1266
+ "############################################################\n",
1267
+ "\n",
1268
+ "def stage2_adversarial_network(stage2_disc, stage2_gen, stage1_gen):\n",
1269
+ " \"\"\"Stage 2 Adversarial Network.\n",
1270
+ "\n",
1271
+ " Args:\n",
1272
+ " stage2_disc: Stage 2 Discriminator Model.\n",
1273
+ " stage2_gen: Stage 2 Generator Model.\n",
1274
+ " stage1_gen: Stage 1 Generator Model.\n",
1275
+ "\n",
1276
+ " Returns:\n",
1277
+ " Stage 2 Adversarial network.\n",
1278
+ " \"\"\"\n",
1279
+ " conditioned_embedding = Input(shape=(1024, ))\n",
1280
+ " latent_space = Input(shape=(100, ))\n",
1281
+ " compressed_replicated = Input(shape=(4, 4, 128))\n",
1282
+ " \n",
1283
+ " # The discriminator is trained separately and stage1_gen is already trained, so both are frozen here by setting trainable = False.\n",
1284
+ " input_images, ca = stage1_gen([conditioned_embedding, latent_space])\n",
1285
+ " stage2_disc.trainable = False\n",
1286
+ " stage1_gen.trainable = False\n",
1287
+ "\n",
1288
+ " images, ca2 = stage2_gen([conditioned_embedding, input_images])\n",
1289
+ " probability = stage2_disc([images, compressed_replicated])\n",
1290
+ "\n",
1291
+ " return Model(inputs=[conditioned_embedding, latent_space, compressed_replicated],\n",
1292
+ " outputs=[probability, ca2])"
1293
+ ]
1294
+ },
1295
+ {
1296
+ "cell_type": "code",
1297
+ "execution_count": 34,
1298
+ "id": "a324bec8",
1299
+ "metadata": {},
1300
+ "outputs": [
1301
+ {
1302
+ "name": "stdout",
1303
+ "output_type": "stream",
1304
+ "text": [
1305
+ "Model: \"model_5\"\n",
1306
+ "__________________________________________________________________________________________________\n",
1307
+ " Layer (type) Output Shape Param # Connected to \n",
1308
+ "==================================================================================================\n",
1309
+ " input_14 (InputLayer) [(None, 1024)] 0 [] \n",
1310
+ " \n",
1311
+ " input_15 (InputLayer) [(None, 100)] 0 [] \n",
1312
+ " \n",
1313
+ " model (Functional) [(None, 64, 64, 3), 10270400 ['input_14[0][0]', \n",
1314
+ " (None, 256)] 'input_15[0][0]'] \n",
1315
+ " \n",
1316
+ " model_3 (Functional) [(None, 256, 256, 3 28645440 ['input_14[0][0]', \n",
1317
+ " ), 'model[1][0]'] \n",
1318
+ " (None, 256)] \n",
1319
+ " \n",
1320
+ " input_16 (InputLayer) [(None, 4, 4, 128)] 0 [] \n",
1321
+ " \n",
1322
+ " model_4 (Functional) (None, 1) 48486401 ['model_3[0][0]', \n",
1323
+ " 'input_16[0][0]'] \n",
1324
+ " \n",
1325
+ "==================================================================================================\n",
1326
+ "Total params: 87,402,241\n",
1327
+ "Trainable params: 28,632,768\n",
1328
+ "Non-trainable params: 58,769,473\n",
1329
+ "__________________________________________________________________________________________________\n"
1330
+ ]
1331
+ }
1332
+ ],
1333
+ "source": [
1334
+ "adversarial_stage2 = stage2_adversarial_network(discriminator_stage2, generator_stage2, generator)\n",
1335
+ "adversarial_stage2.summary()"
1336
+ ]
1337
+ },
1338
+ {
1339
+ "cell_type": "code",
1340
+ "execution_count": 35,
1341
+ "id": "75ce4927",
1342
+ "metadata": {},
1343
+ "outputs": [],
1344
+ "source": [
1345
+ "class StackGanStage2(object):\n",
1346
+ " \"\"\"StackGAN Stage 2 class.\n",
1347
+ "\n",
1348
+ " Args:\n",
1349
+ " epochs: Number of epochs\n",
1350
+ " z_dim: Latent space dimensions\n",
1351
+ " batch_size: Batch Size\n",
1352
+ " enable_function: Intended to toggle tf.function decoration of training; it is only stored on the instance and is not read by this class.\n",
1353
+ " stage2_generator_lr: Learning rate for stage 2 generator\n",
1354
+ " stage2_discriminator_lr: Learning rate for stage 2 discriminator\n",
1355
+ " \"\"\"\n",
1356
+ " def __init__(self, epochs=500, z_dim=100, batch_size=64, enable_function=True, stage2_generator_lr=0.0002, stage2_discriminator_lr=0.0002):\n",
1357
+ " self.epochs = epochs\n",
1358
+ " self.z_dim = z_dim\n",
1359
+ " self.enable_function = enable_function\n",
1360
+ " self.stage1_generator_lr = stage2_generator_lr\n",
1361
+ " self.stage1_discriminator_lr = stage2_discriminator_lr\n",
1362
+ " self.low_image_size = 64\n",
1363
+ " self.high_image_size = 256\n",
1364
+ " self.conditioning_dim = 128\n",
1365
+ " self.batch_size = batch_size\n",
1366
+ " self.stage2_generator_optimizer = Adam(lr=stage2_generator_lr, beta_1=0.5, beta_2=0.999)\n",
1367
+ " self.stage2_discriminator_optimizer = Adam(lr=stage2_discriminator_lr, beta_1=0.5, beta_2=0.999)\n",
1368
+ " self.stage1_generator = build_stage1_generator()\n",
1369
+ " self.stage1_generator.compile(loss='binary_crossentropy', optimizer=self.stage2_generator_optimizer)\n",
1370
+ " self.stage1_generator.load_weights('weights/stage1_gen.h5')\n",
1371
+ " self.stage2_generator = build_stage2_generator()\n",
1372
+ " self.stage2_generator.compile(loss='binary_crossentropy', optimizer=self.stage2_generator_optimizer)\n",
1373
+ "\n",
1374
+ " self.stage2_discriminator = build_stage2_discriminator()\n",
1375
+ " self.stage2_discriminator.compile(loss='binary_crossentropy', optimizer=self.stage2_discriminator_optimizer)\n",
1376
+ "\n",
1377
+ " self.ca_network = build_ca_network()\n",
1378
+ " self.ca_network.compile(loss='binary_crossentropy', optimizer='Adam')\n",
1379
+ "\n",
1380
+ " self.embedding_compressor = build_embedding_compressor()\n",
1381
+ " self.embedding_compressor.compile(loss='binary_crossentropy', optimizer='Adam')\n",
1382
+ "\n",
1383
+ " self.stage2_adversarial = stage2_adversarial_network(self.stage2_discriminator, self.stage2_generator, self.stage1_generator)\n",
1384
+ " self.stage2_adversarial.compile(loss=['binary_crossentropy', adversarial_loss], loss_weights=[1, 2.0], optimizer=self.stage2_generator_optimizer)\t\n",
1385
+ "\n",
1386
+ " self.checkpoint2 = tf.train.Checkpoint(\n",
1387
+ " generator_optimizer=self.stage2_generator_optimizer,\n",
1388
+ " discriminator_optimizer=self.stage2_discriminator_optimizer,\n",
1389
+ " generator=self.stage2_generator,\n",
1390
+ " discriminator=self.stage2_discriminator,\n",
1391
+ " generator1=self.stage1_generator)\n",
1392
+ "\n",
1393
+ " def visualize_stage2(self):\n",
1394
+ " \"\"\"Creates a TensorBoard callback and calls set_model on it with the Stage 2 generator, then the Stage 2 discriminator.\n",
1395
+ " \"\"\"\n",
1396
+ " tb = TensorBoard(log_dir=\"logs/\".format(time.time()))\n",
1397
+ " tb.set_model(self.stage2_generator)\n",
1398
+ " tb.set_model(self.stage2_discriminator)\n",
1399
+ "\n",
1400
+ " def train_stage2(self):\n",
1401
+ " \"\"\"Trains Stage 2 StackGAN.\n",
1402
+ " \"\"\"\n",
1403
+ " x_high_train, y_high_train, high_train_embeds = load_data(filename_path=filename_path_train, class_id_path=class_id_path_train,\n",
1404
+ " dataset_path=dataset_path, embeddings_path=embeddings_path_train, size=(256, 256))\n",
1405
+ "\n",
1406
+ " x_high_test, y_high_test, high_test_embeds = load_data(filename_path=filename_path_test, class_id_path=class_id_path_test, \n",
1407
+ " dataset_path=dataset_path, embeddings_path=embeddings_path_test, size=(256, 256))\n",
1408
+ "\n",
1409
+ " x_low_train, y_low_train, low_train_embeds = load_data(filename_path=filename_path_train, class_id_path=class_id_path_train,\n",
1410
+ " dataset_path=dataset_path, embeddings_path=embeddings_path_train, size=(64, 64))\n",
1411
+ "\n",
1412
+ " x_low_test, y_low_test, low_test_embeds = load_data(filename_path=filename_path_test, class_id_path=class_id_path_test, \n",
1413
+ " dataset_path=dataset_path, embeddings_path=embeddings_path_test, size=(64, 64))\n",
1414
+ "\n",
1415
+ " real = np.ones((self.batch_size, 1), dtype='float') * 0.9\n",
1416
+ " fake = np.zeros((self.batch_size, 1), dtype='float') * 0.1\n",
1417
+ "\n",
1418
+ " for epoch in range(self.epochs):\n",
1419
+ " print(f'Epoch: {epoch}')\n",
1420
+ "\n",
1421
+ " gen_loss = []\n",
1422
+ " disc_loss = []\n",
1423
+ "\n",
1424
+ " num_batches = int(x_high_train.shape[0] / self.batch_size)\n",
1425
+ "\n",
1426
+ " for i in range(num_batches):\n",
1427
+ "\n",
1428
+ " latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))\n",
1429
+ " embedding_text = high_train_embeds[i * self.batch_size:(i + 1) * self.batch_size]\n",
1430
+ " compressed_embedding = self.embedding_compressor.predict_on_batch(embedding_text)\n",
1431
+ " compressed_embedding = np.reshape(compressed_embedding, (-1, 1, 1, self.conditioning_dim))\n",
1432
+ " compressed_embedding = np.tile(compressed_embedding, (1, 4, 4, 1))\n",
1433
+ "\n",
1434
+ " image_batch = x_high_train[i * self.batch_size:(i+1) * self.batch_size]\n",
1435
+ " image_batch = (image_batch - 127.5) / 127.5\n",
1436
+ " \n",
1437
+ " low_res_fakes, _ = self.stage1_generator.predict([embedding_text, latent_space], verbose=3)\n",
1438
+ " high_res_fakes, _ = self.stage2_generator.predict([embedding_text, low_res_fakes], verbose=3)\n",
1439
+ "\n",
1440
+ " discriminator_loss = self.stage2_discriminator.train_on_batch([image_batch, compressed_embedding],\n",
1441
+ " np.reshape(real, (self.batch_size, 1)))\n",
1442
+ "\n",
1443
+ " discriminator_loss_gen = self.stage2_discriminator.train_on_batch([high_res_fakes, compressed_embedding],\n",
1444
+ " np.reshape(fake, (self.batch_size, 1)))\n",
1445
+ "\n",
1446
+ " discriminator_loss_fake = self.stage2_discriminator.train_on_batch([image_batch[:(self.batch_size-1)], compressed_embedding[1:]],\n",
1447
+ " np.reshape(fake[1:], (self.batch_size - 1, 1)))\n",
1448
+ "\n",
1449
+ " d_loss = 0.5 * np.add(discriminator_loss, 0.5 * np.add(discriminator_loss_gen, discriminator_loss_fake))\n",
1450
+ " disc_loss.append(d_loss)\n",
1451
+ "\n",
1452
+ " print(f'Discriminator Loss: {d_loss}')\n",
1453
+ "\n",
1454
+ " g_loss = self.stage2_adversarial.train_on_batch([embedding_text, latent_space, compressed_embedding],\n",
1455
+ " [K.ones((self.batch_size, 1)) * 0.9, K.ones((self.batch_size, 256)) * 0.9])\n",
1456
+ " gen_loss.append(g_loss)\n",
1457
+ "\n",
1458
+ " print(f'Generator Loss: {g_loss}')\n",
1459
+ "\n",
1460
+ " if epoch % 5 == 0:\n",
1461
+ " latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))\n",
1462
+ " embedding_batch = high_test_embeds[0 : self.batch_size]\n",
1463
+ "\n",
1464
+ " low_fake_images, _ = self.stage1_generator.predict([embedding_batch, latent_space], verbose=3)\n",
1465
+ " high_fake_images, _ = self.stage2_generator.predict([embedding_batch, low_fake_images], verbose=3)\n",
1466
+ "\n",
1467
+ " for i, image in enumerate(high_fake_images[:10]):\n",
1468
+ " save_image(image, f'results_stage2/gen_{epoch}_{i}.png')\n",
1469
+ "\n",
1470
+ " if epoch % 10 == 0:\n",
1471
+ " self.stage2_generator.save_weights('weights/stage2_gen.h5')\n",
1472
+ " self.stage2_discriminator.save_weights(\"weights/stage2_disc.h5\")\n",
1473
+ " self.ca_network.save_weights('weights/stage2_ca.h5')\n",
1474
+ " self.embedding_compressor.save_weights('weights/stage2_embco.h5')\n",
1475
+ " self.stage2_adversarial.save_weights('weights/stage2_adv.h5')\n",
1476
+ "\n",
1477
+ " self.stage2_generator.save_weights('weights/stage2_gen.h5')\n",
1478
+ " self.stage2_discriminator.save_weights(\"weights/stage2_disc.h5\")"
1479
+ ]
1480
+ },
1481
+ {
1482
+ "cell_type": "code",
1483
+ "execution_count": null,
1484
+ "id": "0a91a164",
1485
+ "metadata": {},
1486
+ "outputs": [],
1487
+ "source": [
1488
+ "stage2 = StackGanStage2()\n",
1489
+ "stage2.train_stage2()"
1490
+ ]
1491
+ }
1492
+ ],
1493
+ "metadata": {
1494
+ "kernelspec": {
1495
+ "display_name": "Python 3 (ipykernel)",
1496
+ "language": "python",
1497
+ "name": "python3"
1498
+ },
1499
+ "language_info": {
1500
+ "codemirror_mode": {
1501
+ "name": "ipython",
1502
+ "version": 3
1503
+ },
1504
+ "file_extension": ".py",
1505
+ "mimetype": "text/x-python",
1506
+ "name": "python",
1507
+ "nbconvert_exporter": "python",
1508
+ "pygments_lexer": "ipython3",
1509
+ "version": "3.10.9"
1510
+ }
1511
+ },
1512
+ "nbformat": 4,
1513
+ "nbformat_minor": 5
1514
+ }