Spaces:
Runtime error
Runtime error
Merge pull request #2 from soumik12345/zero-dce
Browse filesZero-reference Deep Curve Estimation for Low-light Image Enhancement
- .gitignore +3 -1
- Dockerfile +17 -0
- README.md +11 -1
- app.py +49 -11
- enhance_me/__init__.py +2 -0
- enhance_me/augmentation.py +35 -0
- enhance_me/commons.py +15 -0
- enhance_me/zero_dce/__init__.py +1 -0
- enhance_me/zero_dce/dataloader.py +79 -0
- enhance_me/zero_dce/dce_net.py +31 -0
- enhance_me/zero_dce/losses/__init__.py +36 -0
- enhance_me/zero_dce/losses/spatial_constancy.py +63 -0
- enhance_me/zero_dce/zero_dce.py +183 -0
- notebooks/enhance_me_train.ipynb +103 -6
- test.py +4 -0
.gitignore
CHANGED
|
@@ -131,4 +131,6 @@ dmypy.json
|
|
| 131 |
# Datasets
|
| 132 |
datasets/
|
| 133 |
**.zip
|
| 134 |
-
**.h5
|
|
|
|
|
|
|
|
|
| 131 |
# Datasets
|
| 132 |
datasets/
|
| 133 |
**.zip
|
| 134 |
+
**.h5
|
| 135 |
+
lol_dataset_**
|
| 136 |
+
wandb**
|
Dockerfile
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Pull Base Image
|
| 2 |
+
FROM tensorflow/tensorflow:latest-gpu-jupyter
|
| 3 |
+
|
| 4 |
+
# Set Working Directory
|
| 5 |
+
RUN mkdir /usr/src/enhance-me
|
| 6 |
+
WORKDIR /usr/src/enhance-me
|
| 7 |
+
|
| 8 |
+
# Set Environment Variables
|
| 9 |
+
ENV PYTHONDONTWRITEBYTECODE 1
|
| 10 |
+
ENV PYTHONUNBUFFERED 1
|
| 11 |
+
|
| 12 |
+
RUN pip install --upgrade pip setuptools wheel
|
| 13 |
+
RUN pip install gdown matplotlib streamlit tqdm wandb
|
| 14 |
+
|
| 15 |
+
COPY . /usr/src/enhance-me/
|
| 16 |
+
|
| 17 |
+
CMD ["jupyter", "notebook", "--port=8888", "--no-browser", "--ip=0.0.0.0", "--allow-root"]
|
README.md
CHANGED
|
@@ -8,4 +8,14 @@ app_file: app.py
|
|
| 8 |
pinned: false
|
| 9 |
---
|
| 10 |
|
| 11 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
pinned: false
|
| 9 |
---
|
| 10 |
|
| 11 |
+
# Enhance Me
|
| 12 |
+
|
| 13 |
+
A unified platform for image enhancement.
|
| 14 |
+
|
| 15 |
+
## Usage
|
| 16 |
+
|
| 17 |
+
### Train using Docker
|
| 18 |
+
|
| 19 |
+
- Build image using `docker build -t enhance-image .`
|
| 20 |
+
|
| 21 |
+
- Run notebook using `docker run -it --gpus all -p 8888:8888 -v $(pwd):/usr/src/enhance-me enhance-image`
|
app.py
CHANGED
|
@@ -1,23 +1,36 @@
|
|
|
|
|
| 1 |
from PIL import Image
|
| 2 |
import streamlit as st
|
| 3 |
from tensorflow.keras import utils, backend
|
| 4 |
|
| 5 |
-
from enhance_me
|
| 6 |
|
| 7 |
|
| 8 |
def get_mirnet_object() -> MIRNet:
|
| 9 |
-
mirnet = MIRNet()
|
| 10 |
-
mirnet.build_model()
|
| 11 |
utils.get_file(
|
| 12 |
"weights_lol_128.h5",
|
| 13 |
"https://github.com/soumik12345/enhance-me/releases/download/v0.2/weights_lol_128.h5",
|
| 14 |
cache_dir=".",
|
| 15 |
cache_subdir="weights",
|
| 16 |
)
|
|
|
|
|
|
|
| 17 |
mirnet.load_weights("./weights/weights_lol_128.h5")
|
| 18 |
return mirnet
|
| 19 |
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def main():
|
| 22 |
st.markdown("# Enhance Me")
|
| 23 |
st.markdown("Made with :heart: by [geekyRakshit](http://github.com/soumik12345)")
|
|
@@ -30,14 +43,39 @@ def main():
|
|
| 30 |
if uploaded_file is not None:
|
| 31 |
original_image = Image.open(uploaded_file)
|
| 32 |
st.image(original_image, caption="original image")
|
| 33 |
-
st.sidebar.
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
|
| 43 |
if __name__ == "__main__":
|
|
|
|
| 1 |
+
import os
|
| 2 |
from PIL import Image
|
| 3 |
import streamlit as st
|
| 4 |
from tensorflow.keras import utils, backend
|
| 5 |
|
| 6 |
+
from enhance_me import MIRNet, ZeroDCE
|
| 7 |
|
| 8 |
|
| 9 |
def get_mirnet_object() -> MIRNet:
|
|
|
|
|
|
|
| 10 |
utils.get_file(
|
| 11 |
"weights_lol_128.h5",
|
| 12 |
"https://github.com/soumik12345/enhance-me/releases/download/v0.2/weights_lol_128.h5",
|
| 13 |
cache_dir=".",
|
| 14 |
cache_subdir="weights",
|
| 15 |
)
|
| 16 |
+
mirnet = MIRNet()
|
| 17 |
+
mirnet.build_model()
|
| 18 |
mirnet.load_weights("./weights/weights_lol_128.h5")
|
| 19 |
return mirnet
|
| 20 |
|
| 21 |
|
| 22 |
+
def get_zero_dce_object(model_alias: str) -> ZeroDCE:
|
| 23 |
+
utils.get_file(
|
| 24 |
+
f"{model_alias}.h5",
|
| 25 |
+
f"https://github.com/soumik12345/enhance-me/releases/download/v0.4/{model_alias}.h5",
|
| 26 |
+
cache_dir=".",
|
| 27 |
+
cache_subdir="weights",
|
| 28 |
+
)
|
| 29 |
+
dce = ZeroDCE()
|
| 30 |
+
dce.load_weights(os.path.join("./weights", f"{model_alias}.h5"))
|
| 31 |
+
return dce
|
| 32 |
+
|
| 33 |
+
|
| 34 |
def main():
|
| 35 |
st.markdown("# Enhance Me")
|
| 36 |
st.markdown("Made with :heart: by [geekyRakshit](http://github.com/soumik12345)")
|
|
|
|
| 43 |
if uploaded_file is not None:
|
| 44 |
original_image = Image.open(uploaded_file)
|
| 45 |
st.image(original_image, caption="original image")
|
| 46 |
+
model_option = st.sidebar.selectbox(
|
| 47 |
+
"Please select the model:",
|
| 48 |
+
(
|
| 49 |
+
"",
|
| 50 |
+
"MIRNet",
|
| 51 |
+
"Zero-DCE (dce_weights_lol_128)",
|
| 52 |
+
"Zero-DCE (dce_weights_lol_128_resize)",
|
| 53 |
+
"Zero-DCE (dce_weights_lol_256)",
|
| 54 |
+
"Zero-DCE (dce_weights_lol_256_resize)",
|
| 55 |
+
"Zero-DCE (dce_weights_unpaired_128)",
|
| 56 |
+
"Zero-DCE (dce_weights_unpaired_128_resize)",
|
| 57 |
+
"Zero-DCE (dce_weights_unpaired_256)",
|
| 58 |
+
"Zero-DCE (dce_weights_unpaired_256_resize)"
|
| 59 |
+
),
|
| 60 |
+
)
|
| 61 |
+
if model_option != "":
|
| 62 |
+
if model_option == "MIRNet":
|
| 63 |
+
st.sidebar.info("Loading MIRNet...")
|
| 64 |
+
mirnet = get_mirnet_object()
|
| 65 |
+
st.sidebar.info("Done!")
|
| 66 |
+
st.sidebar.info("Processing Image...")
|
| 67 |
+
enhanced_image = mirnet.infer(original_image)
|
| 68 |
+
st.sidebar.info("Done!")
|
| 69 |
+
st.image(enhanced_image, caption="enhanced image")
|
| 70 |
+
elif "Zero-DCE" in model_option:
|
| 71 |
+
model_alias = model_option[model_option.find("(") + 1: model_option.find(")")]
|
| 72 |
+
st.sidebar.info("Loading Zero-DCE...")
|
| 73 |
+
zero_dce = get_zero_dce_object(model_alias)
|
| 74 |
+
st.sidebar.info("Done!")
|
| 75 |
+
enhanced_image = zero_dce.infer(original_image)
|
| 76 |
+
st.sidebar.info("Done!")
|
| 77 |
+
st.image(enhanced_image, caption="enhanced image")
|
| 78 |
+
backend.clear_session()
|
| 79 |
|
| 80 |
|
| 81 |
if __name__ == "__main__":
|
enhance_me/__init__.py
CHANGED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .mirnet import MIRNet
|
| 2 |
+
from .zero_dce import ZeroDCE
|
enhance_me/augmentation.py
CHANGED
|
@@ -49,3 +49,38 @@ class AugmentationFactory:
|
|
| 49 |
return tf.image.rot90(input_image, condition), tf.image.rot90(
|
| 50 |
enhanced_image, condition
|
| 51 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
return tf.image.rot90(input_image, condition), tf.image.rot90(
|
| 50 |
enhanced_image, condition
|
| 51 |
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class UnpairedAugmentationFactory:
|
| 55 |
+
def __init__(self, image_size) -> None:
|
| 56 |
+
self.image_size = image_size
|
| 57 |
+
|
| 58 |
+
def random_crop(self, image):
|
| 59 |
+
image_shape = tf.shape(image)[:2]
|
| 60 |
+
crop_w = tf.random.uniform(
|
| 61 |
+
shape=(), maxval=image_shape[1] - self.image_size + 1, dtype=tf.int32
|
| 62 |
+
)
|
| 63 |
+
crop_h = tf.random.uniform(
|
| 64 |
+
shape=(), maxval=image_shape[0] - self.image_size + 1, dtype=tf.int32
|
| 65 |
+
)
|
| 66 |
+
return image[
|
| 67 |
+
crop_h : crop_h + self.image_size, crop_w : crop_w + self.image_size
|
| 68 |
+
]
|
| 69 |
+
|
| 70 |
+
def random_horizontal_flip(self, image):
|
| 71 |
+
return tf.cond(
|
| 72 |
+
tf.random.uniform(shape=(), maxval=1) < 0.5,
|
| 73 |
+
lambda: image,
|
| 74 |
+
lambda: tf.image.flip_left_right(image),
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
def random_vertical_flip(self, image):
|
| 78 |
+
return tf.cond(
|
| 79 |
+
tf.random.uniform(shape=(), maxval=1) < 0.5,
|
| 80 |
+
lambda: image,
|
| 81 |
+
lambda: tf.image.flip_up_down(image),
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
def random_rotate(self, image):
|
| 85 |
+
condition = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
|
| 86 |
+
return tf.image.rot90(image, condition)
|
enhance_me/commons.py
CHANGED
|
@@ -61,3 +61,18 @@ def download_lol_dataset():
|
|
| 61 |
test_enhanced_images = sorted(glob("./datasets/lol_dataset/eval15/high/*"))
|
| 62 |
assert len(test_low_images) == len(test_enhanced_images)
|
| 63 |
return (low_images, enhanced_images), (test_low_images, test_enhanced_images)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
test_enhanced_images = sorted(glob("./datasets/lol_dataset/eval15/high/*"))
|
| 62 |
assert len(test_low_images) == len(test_enhanced_images)
|
| 63 |
return (low_images, enhanced_images), (test_low_images, test_enhanced_images)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def download_unpaired_low_light_dataset():
|
| 67 |
+
utils.get_file(
|
| 68 |
+
"low_light_dataset.zip",
|
| 69 |
+
"https://github.com/soumik12345/enhance-me/releases/download/v0.3/low_light_dataset.zip",
|
| 70 |
+
cache_dir="./",
|
| 71 |
+
cache_subdir="./datasets",
|
| 72 |
+
extract=True,
|
| 73 |
+
)
|
| 74 |
+
low_images = glob("./datasets/low_light_dataset/*.png")
|
| 75 |
+
test_low_images = sorted(glob("./datasets/low_light_dataset/eval15/low/*"))
|
| 76 |
+
test_enhanced_images = sorted(glob("./datasets/low_light_dataset/eval15/high/*"))
|
| 77 |
+
assert len(test_low_images) == len(test_enhanced_images)
|
| 78 |
+
return low_images, (test_low_images, test_enhanced_images)
|
enhance_me/zero_dce/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .zero_dce import ZeroDCE
|
enhance_me/zero_dce/dataloader.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
from ..commons import read_image
|
| 5 |
+
from ..augmentation import UnpairedAugmentationFactory
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class UnpairedLowLightDataset:
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
image_size: int = 256,
|
| 12 |
+
apply_resize: bool = False,
|
| 13 |
+
apply_random_horizontal_flip: bool = True,
|
| 14 |
+
apply_random_vertical_flip: bool = True,
|
| 15 |
+
apply_random_rotation: bool = True,
|
| 16 |
+
) -> None:
|
| 17 |
+
self.augmentation_factory = UnpairedAugmentationFactory(image_size=image_size)
|
| 18 |
+
self.image_size = image_size
|
| 19 |
+
self.apply_resize = apply_resize
|
| 20 |
+
self.apply_random_horizontal_flip = apply_random_horizontal_flip
|
| 21 |
+
self.apply_random_vertical_flip = apply_random_vertical_flip
|
| 22 |
+
self.apply_random_rotation = apply_random_rotation
|
| 23 |
+
|
| 24 |
+
def _resize(self, image):
|
| 25 |
+
return tf.image.resize(image, (self.image_size, self.image_size))
|
| 26 |
+
|
| 27 |
+
def _get_dataset(self, images: List[str], batch_size: int, is_train: bool):
|
| 28 |
+
dataset = tf.data.Dataset.from_tensor_slices((images))
|
| 29 |
+
dataset = dataset.map(read_image, num_parallel_calls=tf.data.AUTOTUNE)
|
| 30 |
+
dataset = (
|
| 31 |
+
dataset.map(
|
| 32 |
+
self.augmentation_factory.random_crop,
|
| 33 |
+
num_parallel_calls=tf.data.AUTOTUNE,
|
| 34 |
+
)
|
| 35 |
+
if not self.apply_resize
|
| 36 |
+
else dataset.map(self._resize, num_parallel_calls=tf.data.AUTOTUNE)
|
| 37 |
+
)
|
| 38 |
+
if is_train:
|
| 39 |
+
dataset = (
|
| 40 |
+
dataset.map(
|
| 41 |
+
self.augmentation_factory.random_horizontal_flip,
|
| 42 |
+
num_parallel_calls=tf.data.AUTOTUNE,
|
| 43 |
+
)
|
| 44 |
+
if self.apply_random_horizontal_flip
|
| 45 |
+
else dataset
|
| 46 |
+
)
|
| 47 |
+
dataset = (
|
| 48 |
+
dataset.map(
|
| 49 |
+
self.augmentation_factory.random_vertical_flip,
|
| 50 |
+
num_parallel_calls=tf.data.AUTOTUNE,
|
| 51 |
+
)
|
| 52 |
+
if self.apply_random_vertical_flip
|
| 53 |
+
else dataset
|
| 54 |
+
)
|
| 55 |
+
dataset = (
|
| 56 |
+
dataset.map(
|
| 57 |
+
self.augmentation_factory.random_rotate,
|
| 58 |
+
num_parallel_calls=tf.data.AUTOTUNE,
|
| 59 |
+
)
|
| 60 |
+
if self.apply_random_rotation
|
| 61 |
+
else dataset
|
| 62 |
+
)
|
| 63 |
+
dataset = dataset.batch(batch_size, drop_remainder=True)
|
| 64 |
+
return dataset
|
| 65 |
+
|
| 66 |
+
def get_datasets(
|
| 67 |
+
self,
|
| 68 |
+
images: List[str],
|
| 69 |
+
val_split: float = 0.2,
|
| 70 |
+
batch_size: int = 16,
|
| 71 |
+
):
|
| 72 |
+
split_index = int(len(images) * (1 - val_split))
|
| 73 |
+
train_images = images[:split_index]
|
| 74 |
+
val_images = images[split_index:]
|
| 75 |
+
print(f"Number of train data points: {len(train_images)}")
|
| 76 |
+
print(f"Number of validation data points: {len(val_images)}")
|
| 77 |
+
train_dataset = self._get_dataset(train_images, batch_size, is_train=True)
|
| 78 |
+
val_dataset = self._get_dataset(val_images, batch_size, is_train=False)
|
| 79 |
+
return train_dataset, val_dataset
|
enhance_me/zero_dce/dce_net.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from tensorflow.keras import layers, Input, Model
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def build_dce_net() -> Model:
|
| 6 |
+
input_image = Input(shape=[None, None, 3])
|
| 7 |
+
conv1 = layers.Conv2D(
|
| 8 |
+
32, (3, 3), strides=(1, 1), activation="relu", padding="same"
|
| 9 |
+
)(input_image)
|
| 10 |
+
conv2 = layers.Conv2D(
|
| 11 |
+
32, (3, 3), strides=(1, 1), activation="relu", padding="same"
|
| 12 |
+
)(conv1)
|
| 13 |
+
conv3 = layers.Conv2D(
|
| 14 |
+
32, (3, 3), strides=(1, 1), activation="relu", padding="same"
|
| 15 |
+
)(conv2)
|
| 16 |
+
conv4 = layers.Conv2D(
|
| 17 |
+
32, (3, 3), strides=(1, 1), activation="relu", padding="same"
|
| 18 |
+
)(conv3)
|
| 19 |
+
int_con1 = layers.Concatenate(axis=-1)([conv4, conv3])
|
| 20 |
+
conv5 = layers.Conv2D(
|
| 21 |
+
32, (3, 3), strides=(1, 1), activation="relu", padding="same"
|
| 22 |
+
)(int_con1)
|
| 23 |
+
int_con2 = layers.Concatenate(axis=-1)([conv5, conv2])
|
| 24 |
+
conv6 = layers.Conv2D(
|
| 25 |
+
32, (3, 3), strides=(1, 1), activation="relu", padding="same"
|
| 26 |
+
)(int_con2)
|
| 27 |
+
int_con3 = layers.Concatenate(axis=-1)([conv6, conv1])
|
| 28 |
+
x_r = layers.Conv2D(24, (3, 3), strides=(1, 1), activation="tanh", padding="same")(
|
| 29 |
+
int_con3
|
| 30 |
+
)
|
| 31 |
+
return Model(inputs=input_image, outputs=x_r)
|
enhance_me/zero_dce/losses/__init__.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
|
| 3 |
+
from .spatial_constancy import SpatialConsistencyLoss
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def color_constancy_loss(x):
|
| 7 |
+
mean_rgb = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
|
| 8 |
+
mean_r, mean_g, mean_b = (
|
| 9 |
+
mean_rgb[:, :, :, 0],
|
| 10 |
+
mean_rgb[:, :, :, 1],
|
| 11 |
+
mean_rgb[:, :, :, 2],
|
| 12 |
+
)
|
| 13 |
+
diff_rg = tf.square(mean_r - mean_g)
|
| 14 |
+
diff_rb = tf.square(mean_r - mean_b)
|
| 15 |
+
diff_gb = tf.square(mean_b - mean_g)
|
| 16 |
+
return tf.sqrt(tf.square(diff_rg) + tf.square(diff_rb) + tf.square(diff_gb))
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def exposure_loss(x, mean_val=0.6):
|
| 20 |
+
x = tf.reduce_mean(x, axis=3, keepdims=True)
|
| 21 |
+
mean = tf.nn.avg_pool2d(x, ksize=16, strides=16, padding="VALID")
|
| 22 |
+
return tf.reduce_mean(tf.square(mean - mean_val))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def illumination_smoothness_loss(x):
|
| 26 |
+
batch_size = tf.shape(x)[0]
|
| 27 |
+
h_x = tf.shape(x)[1]
|
| 28 |
+
w_x = tf.shape(x)[2]
|
| 29 |
+
count_h = (tf.shape(x)[2] - 1) * tf.shape(x)[3]
|
| 30 |
+
count_w = tf.shape(x)[2] * (tf.shape(x)[3] - 1)
|
| 31 |
+
h_tv = tf.reduce_sum(tf.square((x[:, 1:, :, :] - x[:, : h_x - 1, :, :])))
|
| 32 |
+
w_tv = tf.reduce_sum(tf.square((x[:, :, 1:, :] - x[:, :, : w_x - 1, :])))
|
| 33 |
+
batch_size = tf.cast(batch_size, dtype=tf.float32)
|
| 34 |
+
count_h = tf.cast(count_h, dtype=tf.float32)
|
| 35 |
+
count_w = tf.cast(count_w, dtype=tf.float32)
|
| 36 |
+
return 2 * (h_tv / count_h + w_tv / count_w) / batch_size
|
enhance_me/zero_dce/losses/spatial_constancy.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from tensorflow.keras import losses
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class SpatialConsistencyLoss(losses.Loss):
|
| 6 |
+
def __init__(self, **kwargs):
|
| 7 |
+
super(SpatialConsistencyLoss, self).__init__(reduction="none")
|
| 8 |
+
|
| 9 |
+
self.left_kernel = tf.constant(
|
| 10 |
+
[[[[0, 0, 0]], [[-1, 1, 0]], [[0, 0, 0]]]], dtype=tf.float32
|
| 11 |
+
)
|
| 12 |
+
self.right_kernel = tf.constant(
|
| 13 |
+
[[[[0, 0, 0]], [[0, 1, -1]], [[0, 0, 0]]]], dtype=tf.float32
|
| 14 |
+
)
|
| 15 |
+
self.up_kernel = tf.constant(
|
| 16 |
+
[[[[0, -1, 0]], [[0, 1, 0]], [[0, 0, 0]]]], dtype=tf.float32
|
| 17 |
+
)
|
| 18 |
+
self.down_kernel = tf.constant(
|
| 19 |
+
[[[[0, 0, 0]], [[0, 1, 0]], [[0, -1, 0]]]], dtype=tf.float32
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
def call(self, y_true, y_pred):
|
| 23 |
+
|
| 24 |
+
original_mean = tf.reduce_mean(y_true, 3, keepdims=True)
|
| 25 |
+
enhanced_mean = tf.reduce_mean(y_pred, 3, keepdims=True)
|
| 26 |
+
original_pool = tf.nn.avg_pool2d(
|
| 27 |
+
original_mean, ksize=4, strides=4, padding="VALID"
|
| 28 |
+
)
|
| 29 |
+
enhanced_pool = tf.nn.avg_pool2d(
|
| 30 |
+
enhanced_mean, ksize=4, strides=4, padding="VALID"
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
d_original_left = tf.nn.conv2d(
|
| 34 |
+
original_pool, self.left_kernel, strides=[1, 1, 1, 1], padding="SAME"
|
| 35 |
+
)
|
| 36 |
+
d_original_right = tf.nn.conv2d(
|
| 37 |
+
original_pool, self.right_kernel, strides=[1, 1, 1, 1], padding="SAME"
|
| 38 |
+
)
|
| 39 |
+
d_original_up = tf.nn.conv2d(
|
| 40 |
+
original_pool, self.up_kernel, strides=[1, 1, 1, 1], padding="SAME"
|
| 41 |
+
)
|
| 42 |
+
d_original_down = tf.nn.conv2d(
|
| 43 |
+
original_pool, self.down_kernel, strides=[1, 1, 1, 1], padding="SAME"
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
d_enhanced_left = tf.nn.conv2d(
|
| 47 |
+
enhanced_pool, self.left_kernel, strides=[1, 1, 1, 1], padding="SAME"
|
| 48 |
+
)
|
| 49 |
+
d_enhanced_right = tf.nn.conv2d(
|
| 50 |
+
enhanced_pool, self.right_kernel, strides=[1, 1, 1, 1], padding="SAME"
|
| 51 |
+
)
|
| 52 |
+
d_enhanced_up = tf.nn.conv2d(
|
| 53 |
+
enhanced_pool, self.up_kernel, strides=[1, 1, 1, 1], padding="SAME"
|
| 54 |
+
)
|
| 55 |
+
d_enhanced_down = tf.nn.conv2d(
|
| 56 |
+
enhanced_pool, self.down_kernel, strides=[1, 1, 1, 1], padding="SAME"
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
d_left = tf.square(d_original_left - d_enhanced_left)
|
| 60 |
+
d_right = tf.square(d_original_right - d_enhanced_right)
|
| 61 |
+
d_up = tf.square(d_original_up - d_enhanced_up)
|
| 62 |
+
d_down = tf.square(d_original_down - d_enhanced_down)
|
| 63 |
+
return d_left + d_right + d_up + d_down
|
enhance_me/zero_dce/zero_dce.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
import tensorflow as tf
|
| 7 |
+
from tensorflow import keras
|
| 8 |
+
from tensorflow.keras import optimizers, mixed_precision, Model
|
| 9 |
+
from wandb.keras import WandbCallback
|
| 10 |
+
|
| 11 |
+
from .dce_net import build_dce_net
|
| 12 |
+
from .dataloader import UnpairedLowLightDataset
|
| 13 |
+
from .losses import (
|
| 14 |
+
color_constancy_loss,
|
| 15 |
+
exposure_loss,
|
| 16 |
+
illumination_smoothness_loss,
|
| 17 |
+
SpatialConsistencyLoss,
|
| 18 |
+
)
|
| 19 |
+
from ..commons import (
|
| 20 |
+
download_lol_dataset,
|
| 21 |
+
download_unpaired_low_light_dataset,
|
| 22 |
+
init_wandb,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ZeroDCE(Model):
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
experiment_name=None,
|
| 30 |
+
wandb_api_key=None,
|
| 31 |
+
use_mixed_precision: bool = False,
|
| 32 |
+
**kwargs
|
| 33 |
+
):
|
| 34 |
+
super(ZeroDCE, self).__init__(**kwargs)
|
| 35 |
+
self.experiment_name = experiment_name
|
| 36 |
+
if use_mixed_precision:
|
| 37 |
+
policy = mixed_precision.Policy("mixed_float16")
|
| 38 |
+
mixed_precision.set_global_policy(policy)
|
| 39 |
+
if wandb_api_key is not None:
|
| 40 |
+
init_wandb("zero-dce", experiment_name, wandb_api_key)
|
| 41 |
+
self.using_wandb = True
|
| 42 |
+
else:
|
| 43 |
+
self.using_wandb = False
|
| 44 |
+
self.dce_model = build_dce_net()
|
| 45 |
+
|
| 46 |
+
def compile(self, learning_rate, **kwargs):
|
| 47 |
+
super(ZeroDCE, self).compile(**kwargs)
|
| 48 |
+
self.optimizer = optimizers.Adam(learning_rate=learning_rate)
|
| 49 |
+
self.spatial_constancy_loss = SpatialConsistencyLoss(reduction="none")
|
| 50 |
+
|
| 51 |
+
def get_enhanced_image(self, data, output):
|
| 52 |
+
r1 = output[:, :, :, :3]
|
| 53 |
+
r2 = output[:, :, :, 3:6]
|
| 54 |
+
r3 = output[:, :, :, 6:9]
|
| 55 |
+
r4 = output[:, :, :, 9:12]
|
| 56 |
+
r5 = output[:, :, :, 12:15]
|
| 57 |
+
r6 = output[:, :, :, 15:18]
|
| 58 |
+
r7 = output[:, :, :, 18:21]
|
| 59 |
+
r8 = output[:, :, :, 21:24]
|
| 60 |
+
x = data + r1 * (tf.square(data) - data)
|
| 61 |
+
x = x + r2 * (tf.square(x) - x)
|
| 62 |
+
x = x + r3 * (tf.square(x) - x)
|
| 63 |
+
enhanced_image = x + r4 * (tf.square(x) - x)
|
| 64 |
+
x = enhanced_image + r5 * (tf.square(enhanced_image) - enhanced_image)
|
| 65 |
+
x = x + r6 * (tf.square(x) - x)
|
| 66 |
+
x = x + r7 * (tf.square(x) - x)
|
| 67 |
+
enhanced_image = x + r8 * (tf.square(x) - x)
|
| 68 |
+
return enhanced_image
|
| 69 |
+
|
| 70 |
+
def call(self, data):
|
| 71 |
+
dce_net_output = self.dce_model(data)
|
| 72 |
+
return self.get_enhanced_image(data, dce_net_output)
|
| 73 |
+
|
| 74 |
+
def compute_losses(self, data, output):
|
| 75 |
+
enhanced_image = self.get_enhanced_image(data, output)
|
| 76 |
+
loss_illumination = 200 * illumination_smoothness_loss(output)
|
| 77 |
+
loss_spatial_constancy = tf.reduce_mean(
|
| 78 |
+
self.spatial_constancy_loss(enhanced_image, data)
|
| 79 |
+
)
|
| 80 |
+
loss_color_constancy = 5 * tf.reduce_mean(color_constancy_loss(enhanced_image))
|
| 81 |
+
loss_exposure = 10 * tf.reduce_mean(exposure_loss(enhanced_image))
|
| 82 |
+
total_loss = (
|
| 83 |
+
loss_illumination
|
| 84 |
+
+ loss_spatial_constancy
|
| 85 |
+
+ loss_color_constancy
|
| 86 |
+
+ loss_exposure
|
| 87 |
+
)
|
| 88 |
+
return {
|
| 89 |
+
"total_loss": total_loss,
|
| 90 |
+
"illumination_smoothness_loss": loss_illumination,
|
| 91 |
+
"spatial_constancy_loss": loss_spatial_constancy,
|
| 92 |
+
"color_constancy_loss": loss_color_constancy,
|
| 93 |
+
"exposure_loss": loss_exposure,
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
def train_step(self, data):
|
| 97 |
+
with tf.GradientTape() as tape:
|
| 98 |
+
output = self.dce_model(data)
|
| 99 |
+
losses = self.compute_losses(data, output)
|
| 100 |
+
gradients = tape.gradient(
|
| 101 |
+
losses["total_loss"], self.dce_model.trainable_weights
|
| 102 |
+
)
|
| 103 |
+
self.optimizer.apply_gradients(zip(gradients, self.dce_model.trainable_weights))
|
| 104 |
+
return losses
|
| 105 |
+
|
| 106 |
+
def test_step(self, data):
|
| 107 |
+
output = self.dce_model(data)
|
| 108 |
+
return self.compute_losses(data, output)
|
| 109 |
+
|
| 110 |
+
def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
|
| 111 |
+
"""While saving the weights, we simply save the weights of the DCE-Net"""
|
| 112 |
+
self.dce_model.save_weights(
|
| 113 |
+
filepath, overwrite=overwrite, save_format=save_format, options=options
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
|
| 117 |
+
"""While loading the weights, we simply load the weights of the DCE-Net"""
|
| 118 |
+
self.dce_model.load_weights(
|
| 119 |
+
filepath=filepath,
|
| 120 |
+
by_name=by_name,
|
| 121 |
+
skip_mismatch=skip_mismatch,
|
| 122 |
+
options=options,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
def build_datasets(
|
| 126 |
+
self,
|
| 127 |
+
image_size: int = 256,
|
| 128 |
+
dataset_label: str = "lol",
|
| 129 |
+
apply_resize: bool = False,
|
| 130 |
+
apply_random_horizontal_flip: bool = True,
|
| 131 |
+
apply_random_vertical_flip: bool = True,
|
| 132 |
+
apply_random_rotation: bool = True,
|
| 133 |
+
val_split: float = 0.2,
|
| 134 |
+
batch_size: int = 16,
|
| 135 |
+
) -> None:
|
| 136 |
+
if dataset_label == "lol":
|
| 137 |
+
(self.low_images, _), (self.test_low_images, _) = download_lol_dataset()
|
| 138 |
+
elif dataset_label == "unpaired":
|
| 139 |
+
self.low_images, (
|
| 140 |
+
self.test_low_images,
|
| 141 |
+
_,
|
| 142 |
+
) = download_unpaired_low_light_dataset()
|
| 143 |
+
data_loader = UnpairedLowLightDataset(
|
| 144 |
+
image_size,
|
| 145 |
+
apply_resize,
|
| 146 |
+
apply_random_horizontal_flip,
|
| 147 |
+
apply_random_vertical_flip,
|
| 148 |
+
apply_random_rotation,
|
| 149 |
+
)
|
| 150 |
+
self.train_dataset, self.val_dataset = data_loader.get_datasets(
|
| 151 |
+
self.low_images, val_split, batch_size
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
def train(self, epochs: int):
|
| 155 |
+
log_dir = os.path.join(
|
| 156 |
+
self.experiment_name,
|
| 157 |
+
"logs",
|
| 158 |
+
datetime.now().strftime("%Y%m%d-%H%M%S"),
|
| 159 |
+
)
|
| 160 |
+
tensorboard_callback = keras.callbacks.TensorBoard(log_dir, histogram_freq=1)
|
| 161 |
+
callbacks = [tensorboard_callback]
|
| 162 |
+
if self.using_wandb:
|
| 163 |
+
callbacks += [WandbCallback()]
|
| 164 |
+
history = self.fit(
|
| 165 |
+
self.train_dataset,
|
| 166 |
+
validation_data=self.val_dataset,
|
| 167 |
+
epochs=epochs,
|
| 168 |
+
callbacks=callbacks,
|
| 169 |
+
)
|
| 170 |
+
return history
|
| 171 |
+
|
| 172 |
+
def infer(self, original_image):
|
| 173 |
+
image = keras.preprocessing.image.img_to_array(original_image)
|
| 174 |
+
image = image.astype("float32") / 255.0
|
| 175 |
+
image = np.expand_dims(image, axis=0)
|
| 176 |
+
output_image = self.call(image)
|
| 177 |
+
output_image = tf.cast((output_image[0, :, :, :] * 255), dtype=np.uint8)
|
| 178 |
+
output_image = Image.fromarray(output_image.numpy())
|
| 179 |
+
return output_image
|
| 180 |
+
|
| 181 |
+
def infer_from_file(self, original_image_file: str):
|
| 182 |
+
original_image = Image.open(original_image_file)
|
| 183 |
+
return self.infer(original_image)
|
notebooks/enhance_me_train.ipynb
CHANGED
|
@@ -37,11 +37,12 @@
|
|
| 37 |
"import os\n",
|
| 38 |
"import sys\n",
|
| 39 |
"\n",
|
| 40 |
-
"sys.path.append(\"
|
| 41 |
"\n",
|
| 42 |
"from PIL import Image\n",
|
| 43 |
"from enhance_me import commons\n",
|
| 44 |
-
"from enhance_me.mirnet import MIRNet"
|
|
|
|
| 45 |
]
|
| 46 |
},
|
| 47 |
{
|
|
@@ -170,7 +171,7 @@
|
|
| 170 |
" enhanced_image = mirnet.infer(original_image)\n",
|
| 171 |
" ground_truth = Image.open(mirnet.test_enhanced_images[index])\n",
|
| 172 |
" commons.plot_results(\n",
|
| 173 |
-
" [original_image, ground_truth,
|
| 174 |
" [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
|
| 175 |
" (18, 18),\n",
|
| 176 |
" )"
|
|
@@ -183,6 +184,92 @@
|
|
| 183 |
"id": "dO-IbNQHkB3R"
|
| 184 |
},
|
| 185 |
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
"source": []
|
| 187 |
}
|
| 188 |
],
|
|
@@ -197,13 +284,23 @@
|
|
| 197 |
"provenance": []
|
| 198 |
},
|
| 199 |
"kernelspec": {
|
| 200 |
-
"display_name": "Python 3",
|
|
|
|
| 201 |
"name": "python3"
|
| 202 |
},
|
| 203 |
"language_info": {
|
| 204 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
}
|
| 206 |
},
|
| 207 |
"nbformat": 4,
|
| 208 |
-
"nbformat_minor":
|
| 209 |
}
|
|
|
|
| 37 |
"import os\n",
|
| 38 |
"import sys\n",
|
| 39 |
"\n",
|
| 40 |
+
"sys.path.append(\"..\")\n",
|
| 41 |
"\n",
|
| 42 |
"from PIL import Image\n",
|
| 43 |
"from enhance_me import commons\n",
|
| 44 |
+
"from enhance_me.mirnet import MIRNet\n",
|
| 45 |
+
"from enhance_me.zero_dce import ZeroDCE"
|
| 46 |
]
|
| 47 |
},
|
| 48 |
{
|
|
|
|
| 171 |
" enhanced_image = mirnet.infer(original_image)\n",
|
| 172 |
" ground_truth = Image.open(mirnet.test_enhanced_images[index])\n",
|
| 173 |
" commons.plot_results(\n",
|
| 174 |
+
" [original_image, ground_truth, enhanced_image],\n",
|
| 175 |
" [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
|
| 176 |
" (18, 18),\n",
|
| 177 |
" )"
|
|
|
|
| 184 |
"id": "dO-IbNQHkB3R"
|
| 185 |
},
|
| 186 |
"outputs": [],
|
| 187 |
+
"source": [
|
| 188 |
+
"# @title Zero-DCE Train Configs\n",
|
| 189 |
+
"\n",
|
| 190 |
+
"experiment_name = \"unpaired_low_light_256_resize\" # @param {type:\"string\"}\n",
|
| 191 |
+
"image_size = 256 # @param {type:\"integer\"}\n",
|
| 192 |
+
"dataset_label = \"unpaired\" # @param [\"lol\", \"unpaired\"]\n",
|
| 193 |
+
"use_mixed_precision = False # @param {type:\"boolean\"}\n",
|
| 194 |
+
"apply_resize = True # @param {type:\"boolean\"}\n",
|
| 195 |
+
"apply_random_horizontal_flip = True # @param {type:\"boolean\"}\n",
|
| 196 |
+
"apply_random_vertical_flip = True # @param {type:\"boolean\"}\n",
|
| 197 |
+
"apply_random_rotation = True # @param {type:\"boolean\"}\n",
|
| 198 |
+
"wandb_api_key = \"\" # @param {type:\"string\"}\n",
|
| 199 |
+
"val_split = 0.1 # @param {type:\"slider\", min:0.1, max:1.0, step:0.1}\n",
|
| 200 |
+
"batch_size = 16 # @param {type:\"integer\"}\n",
|
| 201 |
+
"learning_rate = 1e-4 # @param {type:\"number\"}\n",
|
| 202 |
+
"epsilon = 1e-3 # @param {type:\"number\"}\n",
|
| 203 |
+
"epochs = 100 # @param {type:\"slider\", min:10, max:100, step:5}"
|
| 204 |
+
]
|
| 205 |
+
},
|
| 206 |
+
{
|
| 207 |
+
"cell_type": "code",
|
| 208 |
+
"execution_count": null,
|
| 209 |
+
"metadata": {},
|
| 210 |
+
"outputs": [],
|
| 211 |
+
"source": [
|
| 212 |
+
"zero_dce = ZeroDCE(\n",
|
| 213 |
+
" experiment_name=experiment_name,\n",
|
| 214 |
+
" wandb_api_key=None if wandb_api_key == \"\" else wandb_api_key,\n",
|
| 215 |
+
" use_mixed_precision=use_mixed_precision\n",
|
| 216 |
+
")"
|
| 217 |
+
]
|
| 218 |
+
},
|
| 219 |
+
{
|
| 220 |
+
"cell_type": "code",
|
| 221 |
+
"execution_count": null,
|
| 222 |
+
"metadata": {},
|
| 223 |
+
"outputs": [],
|
| 224 |
+
"source": [
|
| 225 |
+
"zero_dce.build_datasets(\n",
|
| 226 |
+
" image_size=image_size,\n",
|
| 227 |
+
" dataset_label=dataset_label,\n",
|
| 228 |
+
" apply_resize=apply_resize,\n",
|
| 229 |
+
" apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
|
| 230 |
+
" apply_random_vertical_flip=apply_random_vertical_flip,\n",
|
| 231 |
+
" apply_random_rotation=apply_random_rotation,\n",
|
| 232 |
+
" val_split=val_split,\n",
|
| 233 |
+
" batch_size=batch_size\n",
|
| 234 |
+
")"
|
| 235 |
+
]
|
| 236 |
+
},
|
| 237 |
+
{
|
| 238 |
+
"cell_type": "code",
|
| 239 |
+
"execution_count": null,
|
| 240 |
+
"metadata": {
|
| 241 |
+
"scrolled": false
|
| 242 |
+
},
|
| 243 |
+
"outputs": [],
|
| 244 |
+
"source": [
|
| 245 |
+
"zero_dce.compile(learning_rate=learning_rate)\n",
|
| 246 |
+
"history = zero_dce.train(epochs=epochs)\n",
|
| 247 |
+
"zero_dce.save_weights(os.path.join(experiment_name, \"weights.h5\"))"
|
| 248 |
+
]
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"cell_type": "code",
|
| 252 |
+
"execution_count": null,
|
| 253 |
+
"metadata": {
|
| 254 |
+
"scrolled": false
|
| 255 |
+
},
|
| 256 |
+
"outputs": [],
|
| 257 |
+
"source": [
|
| 258 |
+
"for index, low_image_file in enumerate(zero_dce.test_low_images):\n",
|
| 259 |
+
" original_image = Image.open(low_image_file)\n",
|
| 260 |
+
" enhanced_image = zero_dce.infer(original_image)\n",
|
| 261 |
+
" commons.plot_results(\n",
|
| 262 |
+
" [original_image, enhanced_image],\n",
|
| 263 |
+
" [\"Original Image\", \"Enhanced Image\"],\n",
|
| 264 |
+
" (18, 18),\n",
|
| 265 |
+
" )"
|
| 266 |
+
]
|
| 267 |
+
},
|
| 268 |
+
{
|
| 269 |
+
"cell_type": "code",
|
| 270 |
+
"execution_count": null,
|
| 271 |
+
"metadata": {},
|
| 272 |
+
"outputs": [],
|
| 273 |
"source": []
|
| 274 |
}
|
| 275 |
],
|
|
|
|
| 284 |
"provenance": []
|
| 285 |
},
|
| 286 |
"kernelspec": {
|
| 287 |
+
"display_name": "Python 3 (ipykernel)",
|
| 288 |
+
"language": "python",
|
| 289 |
"name": "python3"
|
| 290 |
},
|
| 291 |
"language_info": {
|
| 292 |
+
"codemirror_mode": {
|
| 293 |
+
"name": "ipython",
|
| 294 |
+
"version": 3
|
| 295 |
+
},
|
| 296 |
+
"file_extension": ".py",
|
| 297 |
+
"mimetype": "text/x-python",
|
| 298 |
+
"name": "python",
|
| 299 |
+
"nbconvert_exporter": "python",
|
| 300 |
+
"pygments_lexer": "ipython3",
|
| 301 |
+
"version": "3.8.10"
|
| 302 |
}
|
| 303 |
},
|
| 304 |
"nbformat": 4,
|
| 305 |
+
"nbformat_minor": 1
|
| 306 |
}
|
test.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enhance_me.commons import download_unpaired_low_light_dataset
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
download_unpaired_low_light_dataset()
|