Spaces:
Runtime error
Runtime error
Commit
·
0f84baa
1
Parent(s):
659a217
added zero-dce model
Browse files
enhance_me/zero_dce/dataloader.py
CHANGED
|
@@ -17,7 +17,7 @@ class UnpairedLowLightDataset:
|
|
| 17 |
self.apply_random_vertical_flip = apply_random_vertical_flip
|
| 18 |
self.apply_random_rotation = apply_random_rotation
|
| 19 |
|
| 20 |
-
def
|
| 21 |
image = tf.io.read_file(image_path)
|
| 22 |
image = tf.image.decode_png(image, channels=3)
|
| 23 |
image = image / 255.0
|
|
@@ -25,7 +25,7 @@ class UnpairedLowLightDataset:
|
|
| 25 |
|
| 26 |
def _get_dataset(self, images: List[str], batch_size: int, is_train: bool):
|
| 27 |
dataset = tf.data.Dataset.from_tensor_slices((images))
|
| 28 |
-
dataset = dataset.map(self.
|
| 29 |
dataset = dataset.map(
|
| 30 |
self.augmentation_factory.random_crop, num_parallel_calls=tf.data.AUTOTUNE
|
| 31 |
)
|
|
|
|
| 17 |
self.apply_random_vertical_flip = apply_random_vertical_flip
|
| 18 |
self.apply_random_rotation = apply_random_rotation
|
| 19 |
|
| 20 |
+
def _load_data(self, image_path):
|
| 21 |
image = tf.io.read_file(image_path)
|
| 22 |
image = tf.image.decode_png(image, channels=3)
|
| 23 |
image = image / 255.0
|
|
|
|
| 25 |
|
| 26 |
def _get_dataset(self, images: List[str], batch_size: int, is_train: bool):
|
| 27 |
dataset = tf.data.Dataset.from_tensor_slices((images))
|
| 28 |
+
dataset = dataset.map(self._load_data, num_parallel_calls=tf.data.AUTOTUNE)
|
| 29 |
dataset = dataset.map(
|
| 30 |
self.augmentation_factory.random_crop, num_parallel_calls=tf.data.AUTOTUNE
|
| 31 |
)
|
enhance_me/zero_dce/losses/__init__.py
CHANGED
|
@@ -5,7 +5,11 @@ from .spatial_constancy import SpatialConsistencyLoss
|
|
| 5 |
|
| 6 |
def color_constancy_loss(x):
|
| 7 |
mean_rgb = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
|
| 8 |
-
mean_r, mean_g, mean_b =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
diff_rg = tf.square(mean_r - mean_g)
|
| 10 |
diff_rb = tf.square(mean_r - mean_b)
|
| 11 |
diff_gb = tf.square(mean_b - mean_g)
|
|
|
|
| 5 |
|
| 6 |
def color_constancy_loss(x):
|
| 7 |
mean_rgb = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
|
| 8 |
+
mean_r, mean_g, mean_b = (
|
| 9 |
+
mean_rgb[:, :, :, 0],
|
| 10 |
+
mean_rgb[:, :, :, 1],
|
| 11 |
+
mean_rgb[:, :, :, 2],
|
| 12 |
+
)
|
| 13 |
diff_rg = tf.square(mean_r - mean_g)
|
| 14 |
diff_rb = tf.square(mean_r - mean_b)
|
| 15 |
diff_gb = tf.square(mean_b - mean_g)
|
enhance_me/zero_dce/models/zero_dce.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from tensorflow.keras import optimizers, Model
|
| 3 |
+
|
| 4 |
+
from .dce_net import build_dce_net
|
| 5 |
+
from ..dataloader import UnpairedLowLightDataset
|
| 6 |
+
from ..losses import (
|
| 7 |
+
color_constancy_loss,
|
| 8 |
+
exposure_loss,
|
| 9 |
+
illumination_smoothness_loss,
|
| 10 |
+
SpatialConsistencyLoss,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ZeroDCE(Model):
|
| 15 |
+
def __init__(self, **kwargs):
|
| 16 |
+
super(ZeroDCE, self).__init__(**kwargs)
|
| 17 |
+
self.dce_model = build_dce_net()
|
| 18 |
+
|
| 19 |
+
def compile(self, learning_rate, **kwargs):
|
| 20 |
+
super(ZeroDCE, self).compile(**kwargs)
|
| 21 |
+
self.optimizer = optimizers.Adam(learning_rate=learning_rate)
|
| 22 |
+
self.spatial_constancy_loss = SpatialConsistencyLoss(reduction="none")
|
| 23 |
+
|
| 24 |
+
def get_enhanced_image(self, data, output):
|
| 25 |
+
r1 = output[:, :, :, :3]
|
| 26 |
+
r2 = output[:, :, :, 3:6]
|
| 27 |
+
r3 = output[:, :, :, 6:9]
|
| 28 |
+
r4 = output[:, :, :, 9:12]
|
| 29 |
+
r5 = output[:, :, :, 12:15]
|
| 30 |
+
r6 = output[:, :, :, 15:18]
|
| 31 |
+
r7 = output[:, :, :, 18:21]
|
| 32 |
+
r8 = output[:, :, :, 21:24]
|
| 33 |
+
x = data + r1 * (tf.square(data) - data)
|
| 34 |
+
x = x + r2 * (tf.square(x) - x)
|
| 35 |
+
x = x + r3 * (tf.square(x) - x)
|
| 36 |
+
enhanced_image = x + r4 * (tf.square(x) - x)
|
| 37 |
+
x = enhanced_image + r5 * (tf.square(enhanced_image) - enhanced_image)
|
| 38 |
+
x = x + r6 * (tf.square(x) - x)
|
| 39 |
+
x = x + r7 * (tf.square(x) - x)
|
| 40 |
+
enhanced_image = x + r8 * (tf.square(x) - x)
|
| 41 |
+
return enhanced_image
|
| 42 |
+
|
| 43 |
+
def call(self, data):
|
| 44 |
+
dce_net_output = self.dce_model(data)
|
| 45 |
+
return self.get_enhanced_image(data, dce_net_output)
|
| 46 |
+
|
| 47 |
+
def compute_losses(self, data, output):
|
| 48 |
+
enhanced_image = self.get_enhanced_image(data, output)
|
| 49 |
+
loss_illumination = 200 * illumination_smoothness_loss(output)
|
| 50 |
+
loss_spatial_constancy = tf.reduce_mean(
|
| 51 |
+
self.spatial_constancy_loss(enhanced_image, data)
|
| 52 |
+
)
|
| 53 |
+
loss_color_constancy = 5 * tf.reduce_mean(color_constancy_loss(enhanced_image))
|
| 54 |
+
loss_exposure = 10 * tf.reduce_mean(exposure_loss(enhanced_image))
|
| 55 |
+
total_loss = (
|
| 56 |
+
loss_illumination
|
| 57 |
+
+ loss_spatial_constancy
|
| 58 |
+
+ loss_color_constancy
|
| 59 |
+
+ loss_exposure
|
| 60 |
+
)
|
| 61 |
+
return {
|
| 62 |
+
"total_loss": total_loss,
|
| 63 |
+
"illumination_smoothness_loss": loss_illumination,
|
| 64 |
+
"spatial_constancy_loss": loss_spatial_constancy,
|
| 65 |
+
"color_constancy_loss": loss_color_constancy,
|
| 66 |
+
"exposure_loss": loss_exposure,
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
def train_step(self, data):
|
| 70 |
+
with tf.GradientTape() as tape:
|
| 71 |
+
output = self.dce_model(data)
|
| 72 |
+
losses = self.compute_losses(data, output)
|
| 73 |
+
gradients = tape.gradient(
|
| 74 |
+
losses["total_loss"], self.dce_model.trainable_weights
|
| 75 |
+
)
|
| 76 |
+
self.optimizer.apply_gradients(zip(gradients, self.dce_model.trainable_weights))
|
| 77 |
+
return losses
|
| 78 |
+
|
| 79 |
+
def test_step(self, data):
|
| 80 |
+
output = self.dce_model(data)
|
| 81 |
+
return self.compute_losses(data, output)
|
| 82 |
+
|
| 83 |
+
def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
|
| 84 |
+
"""While saving the weights, we simply save the weights of the DCE-Net"""
|
| 85 |
+
self.dce_model.save_weights(
|
| 86 |
+
filepath, overwrite=overwrite, save_format=save_format, options=options
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
|
| 90 |
+
"""While loading the weights, we simply load the weights of the DCE-Net"""
|
| 91 |
+
self.dce_model.load_weights(
|
| 92 |
+
filepath=filepath,
|
| 93 |
+
by_name=by_name,
|
| 94 |
+
skip_mismatch=skip_mismatch,
|
| 95 |
+
options=options,
|
| 96 |
+
)
|