# I2VGen-XL / test.py — uploaded by kevinwang676
# Commit 22cf660 (verified): "Update test.py" — 2.3 kB
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout, BatchNormalization
from tensorflow.keras import regularizers
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
# Build, train and checkpoint a two-layer (500 -> 200) regularized autoencoder,
# then reload the best checkpoint and freeze the model.
#
# NOTE(review): `num_in_neurons`, `x_train`, `x_train_noisy`, `x_test` and
# `x_test_noisy` are not defined in this snippet; they must already exist in
# the enclosing module/session before this block runs. The *_noisy / clean
# pairing suggests a denoising setup — confirm against the data-prep code.
with tf.device('/gpu:0'):
    # Sizes of the two encoded representations.
    encoding_dim1 = 500
    encoding_dim2 = 200
    # Regularization strengths for the encoder layers: L1 on activations
    # (encourages sparse codes) and L2 on kernel weights.
    lambda_act = 0.0001
    lambda_weight = 0.001

    # Model input: one feature vector of length `num_in_neurons` per sample.
    input_data = Input(shape=(num_in_neurons,))

    # First encoded representation of the input.
    encoded = Dense(encoding_dim1, activation='relu',
                    activity_regularizer=regularizers.l1(lambda_act),
                    kernel_regularizer=regularizers.l2(lambda_weight),
                    name='encoder1')(input_data)
    encoded = BatchNormalization()(encoded)
    encoded = Dropout(0.5)(encoded)

    # Second (innermost) encoded representation of the input.
    encoded = Dense(encoding_dim2, activation='relu',
                    activity_regularizer=regularizers.l1(lambda_act),
                    kernel_regularizer=regularizers.l2(lambda_weight),
                    name='encoder2')(encoded)
    encoded = BatchNormalization()(encoded)
    encoded = Dropout(0.5)(encoded)

    # First lossy reconstruction of the input.
    decoded = Dense(encoding_dim1, activation='relu', name='decoder1')(encoded)
    decoded = BatchNormalization()(decoded)
    # Final reconstruction, same width as the input. The sigmoid bounds outputs
    # to (0, 1), so inputs are presumably scaled to [0, 1] — TODO confirm.
    decoded = Dense(num_in_neurons, activation='sigmoid', name='decoder2')(decoded)

    # This model maps an input to its reconstruction.
    autoencoder = Model(inputs=input_data, outputs=decoded)
    autoencoder.compile(optimizer=Adam(), loss='mse')

    # Stop after 5 epochs without val_loss improvement, and write the model
    # to best_model.h5 each time val_loss reaches a new best.
    callbacks = [
        EarlyStopping(monitor='val_loss', patience=5, verbose=1),
        ModelCheckpoint('best_model.h5', monitor='val_loss',
                        save_best_only=True, verbose=1),
    ]

    # Training: map the noisy inputs back to the clean targets.
    print('Training the autoencoder')
    autoencoder.fit(x_train_noisy, x_train,
                    epochs=50,
                    batch_size=8,
                    shuffle=True,
                    validation_data=(x_test_noisy, x_test),
                    callbacks=callbacks)

    # EarlyStopping is created without restore_best_weights, so after fit()
    # the model holds the last epoch's weights; reload the best checkpoint.
    autoencoder.load_weights('best_model.h5')

    # Freeze all layers. A change to `trainable` only takes effect when the
    # model (or a model wrapping it) is compiled again after this point.
    autoencoder.trainable = False