"""Keras model wrapper that trains a DCE-Net with four combined losses."""

import tensorflow as tf
from tensorflow.keras import optimizers, Model

from .dce_net import build_dce_net
from ..dataloader import UnpairedLowLightDataset
from ..losses import (
    color_constancy_loss,
    exposure_loss,
    illumination_smoothness_loss,
    SpatialConsistencyLoss,
)


class ZeroDCE(Model):
    """Keras ``Model`` that wraps the network returned by ``build_dce_net``.

    The wrapped network's output is used both to enhance the input image
    (see ``get_enhanced_image``) and to compute the training objective
    (see ``compute_losses``). Only the wrapped network's weights are
    trained, saved, and loaded.
    """

    def __init__(self, **kwargs):
        """Build the wrapped DCE-Net; ``kwargs`` go to ``Model.__init__``."""
        super().__init__(**kwargs)
        self.dce_model = build_dce_net()

    def compile(self, learning_rate, **kwargs):
        """Compile the model with an Adam optimizer.

        Args:
            learning_rate: Learning rate handed to ``optimizers.Adam``.
            **kwargs: Forwarded unchanged to ``Model.compile``.
        """
        super().compile(**kwargs)
        # Set after the base compile so this optimizer is the one used by
        # ``train_step``.
        self.optimizer = optimizers.Adam(learning_rate=learning_rate)
        self.spatial_constancy_loss = SpatialConsistencyLoss(reduction="none")

    def get_enhanced_image(self, data, output):
        """Iteratively adjust ``data`` using channel groups of ``output``.

        Eight updates of the form ``x <- x + r * (x**2 - x)`` are applied
        in sequence, starting from ``x = data``. For the i-th update, ``r``
        is the slice ``output[:, :, :, 3*i : 3*i + 3]``.

        Args:
            data: Rank-4 tensor to enhance.
            output: Rank-4 tensor whose last axis is sliced into eight
                consecutive groups of three channels.
                NOTE(review): presumably the DCE-Net output with 24
                channels -- confirm against ``build_dce_net``.

        Returns:
            The tensor obtained after the eighth update.
        """
        group_width = 3
        iteration_count = 8

        current = data
        for step in range(iteration_count):
            start = step * group_width
            alpha = output[:, :, :, start:start + group_width]
            current = current + alpha * (tf.square(current) - current)
        return current

    def call(self, data):
        """Run the wrapped network on ``data`` and return the enhanced image."""
        return self.get_enhanced_image(data, self.dce_model(data))

    def compute_losses(self, data, output):
        """Compute the individual loss terms and their sum.

        Args:
            data: Input batch that was fed to the wrapped network.
            output: The wrapped network's output for ``data``.

        Returns:
            A dict with the keys ``"total_loss"``,
            ``"illumination_smoothness_loss"``, ``"spatial_constancy_loss"``,
            ``"color_constancy_loss"`` and ``"exposure_loss"``. The total is
            the sum of the four weighted terms.
        """
        enhanced = self.get_enhanced_image(data, output)

        # Weighted terms: 200 (illumination), 1 (spatial), 5 (color),
        # 10 (exposure).
        illumination_term = 200 * illumination_smoothness_loss(output)
        spatial_term = tf.reduce_mean(self.spatial_constancy_loss(enhanced, data))
        color_term = 5 * tf.reduce_mean(color_constancy_loss(enhanced))
        exposure_term = 10 * tf.reduce_mean(exposure_loss(enhanced))

        combined = illumination_term + spatial_term + color_term + exposure_term

        return {
            "total_loss": combined,
            "illumination_smoothness_loss": illumination_term,
            "spatial_constancy_loss": spatial_term,
            "color_constancy_loss": color_term,
            "exposure_loss": exposure_term,
        }

    def train_step(self, data):
        """Run one optimization step on ``data`` and return the loss dict.

        Gradients of ``"total_loss"`` are taken with respect to the wrapped
        network's trainable weights and applied with ``self.optimizer``.
        """
        with tf.GradientTape() as tape:
            net_output = self.dce_model(data)
            loss_dict = self.compute_losses(data, net_output)

        weights = self.dce_model.trainable_weights
        grads = tape.gradient(loss_dict["total_loss"], weights)
        self.optimizer.apply_gradients(zip(grads, weights))
        return loss_dict

    def test_step(self, data):
        """Compute and return the loss dict for ``data`` without updating weights."""
        return self.compute_losses(data, self.dce_model(data))

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        """Save only the wrapped DCE-Net's weights to ``filepath``.

        All arguments are forwarded to ``self.dce_model.save_weights``.
        """
        self.dce_model.save_weights(
            filepath,
            overwrite=overwrite,
            save_format=save_format,
            options=options,
        )

    def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
        """Load weights from ``filepath`` into the wrapped DCE-Net only.

        All arguments are forwarded to ``self.dce_model.load_weights``.
        """
        self.dce_model.load_weights(
            filepath=filepath,
            by_name=by_name,
            skip_mismatch=skip_mismatch,
            options=options,
        )