Spaces:
Runtime error
Runtime error
Commit
·
6fd61b9
1
Parent(s):
865788c
added mirnet model + charbonnier loss
Browse files
- enhance_me/commons.py +4 -0
- enhance_me/mirnet/losses.py +13 -0
- enhance_me/mirnet/models/__init__.py +1 -0
- enhance_me/mirnet/models/dual_attention.py +41 -0
- enhance_me/mirnet/models/mirnet_model.py +14 -0
- enhance_me/mirnet/models/recursive_residual_blocks.py +79 -0
- enhance_me/mirnet/models/skff.py +30 -0
enhance_me/commons.py
CHANGED
|
@@ -7,3 +7,7 @@ def read_image(image_path):
|
|
| 7 |
image.set_shape([None, None, 3])
|
| 8 |
image = tf.cast(image, dtype=tf.float32) / 255.0
|
| 9 |
return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
image.set_shape([None, None, 3])
|
| 8 |
image = tf.cast(image, dtype=tf.float32) / 255.0
|
| 9 |
return image
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def peak_signal_noise_ratio(y_true, y_pred, max_val=1.0):
    """Return the peak signal-to-noise ratio between two image tensors.

    Args:
        y_true: Ground-truth image tensor.
        y_pred: Predicted image tensor.
        max_val: Dynamic range of the pixel values. Defaults to 1.0 because
            ``read_image`` in this module divides pixel values by 255.0;
            the previously hard-coded 255.0 overstated the PSNR for such
            inputs. Pass 255.0 explicitly for images that are not rescaled.

    Returns:
        The result of ``tf.image.psnr`` for the given pair of images.
    """
    return tf.image.psnr(y_pred, y_true, max_val=max_val)
|
enhance_me/mirnet/losses.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from tensorflow.keras import losses
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class CharbonnierLoss(losses.Loss):
    """Charbonnier loss: mean of ``sqrt((y_true - y_pred)^2 + epsilon^2)``.

    Args:
        epsilon: Constant added (squared) under the square root.
        *args: Forwarded to ``losses.Loss``.
        **kwargs: Forwarded to ``losses.Loss``.
    """

    def __init__(self, epsilon: float = 1e-3, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.epsilon = epsilon

    def call(self, y_true, y_pred):
        """Return the mean Charbonnier penalty over all elements."""
        # Align dtypes so that integer targets or non-float32 predictions
        # do not fail on the subtraction / addition below.
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        epsilon = tf.cast(self.epsilon, y_pred.dtype)
        return tf.reduce_mean(
            tf.sqrt(tf.square(y_true - y_pred) + tf.square(epsilon))
        )

    def get_config(self):
        """Return the base loss config extended with ``epsilon``."""
        # Without this, a non-default epsilon is dropped when the loss is
        # serialized and rebuilt from its config.
        config = super().get_config()
        config["epsilon"] = self.epsilon
        return config
|
enhance_me/mirnet/models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# mirnet_model.py defines `mirnet_model`; alias it so the package still
# exposes the `build_mirnet_model` name without an ImportError.
from .mirnet_model import mirnet_model as build_mirnet_model
|
enhance_me/mirnet/models/dual_attention.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from tensorflow.keras import layers
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def spatial_attention_block(input_tensor):
    """Re-weight ``input_tensor`` with a single-channel spatial attention map.

    The max and the mean over the last axis are concatenated, passed through
    a 1x1 convolution and a sigmoid, and the result is multiplied with the
    input.

    Args:
        input_tensor: Feature tensor; the reductions run over its last axis.

    Returns:
        ``input_tensor`` multiplied by the sigmoid attention map.
    """
    # The variable names now match the reductions they hold (they were
    # swapped before).
    max_pooling = tf.reduce_max(input_tensor, axis=-1)
    max_pooling = tf.expand_dims(max_pooling, axis=-1)
    average_pooling = tf.reduce_mean(input_tensor, axis=-1)
    average_pooling = tf.expand_dims(average_pooling, axis=-1)
    # Keep the original order (max first, mean second) so the layout of the
    # 1x1 convolution kernel is unchanged.
    concatenated = layers.Concatenate(axis=-1)([max_pooling, average_pooling])
    feature_map = layers.Conv2D(1, kernel_size=(1, 1))(concatenated)
    feature_map = tf.nn.sigmoid(feature_map)
    return input_tensor * feature_map
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def channel_attention_block(input_tensor):
    """Re-weight the channels of ``input_tensor`` with a learned gate.

    The input is globally average-pooled, reshaped to ``(-1, 1, 1, C)``,
    passed through a 1x1 ReLU convolution with ``C // 8`` filters and a 1x1
    sigmoid convolution with ``C`` filters, then multiplied with the input.

    Args:
        input_tensor: Feature tensor whose last dimension ``C`` is static.

    Returns:
        ``input_tensor`` multiplied by the per-channel gate.
    """
    num_channels = input_tensor.shape[-1]
    pooled = layers.GlobalAveragePooling2D()(input_tensor)
    descriptor = tf.reshape(pooled, shape=(-1, 1, 1, num_channels))
    # Bottleneck first, then expansion back to the full channel count.
    squeeze = layers.Conv2D(
        filters=num_channels // 8, kernel_size=(1, 1), activation="relu"
    )
    excite = layers.Conv2D(
        filters=num_channels, kernel_size=(1, 1), activation="sigmoid"
    )
    gate = excite(squeeze(descriptor))
    return input_tensor * gate
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def dual_attention_unit_block(input_tensor):
    """Apply channel and spatial attention, then add the result to the input.

    Two 3x3 convolutions (the first with ReLU) produce a feature map that is
    fed to ``channel_attention_block`` and ``spatial_attention_block``. Both
    outputs are concatenated, projected back to the input channel count with
    a 1x1 convolution, and added to ``input_tensor``.

    Args:
        input_tensor: Feature tensor whose last dimension is static.

    Returns:
        The sum of ``input_tensor`` and the fused attention features.
    """
    num_channels = input_tensor.shape[-1]
    conv_relu = layers.Conv2D(
        num_channels, kernel_size=(3, 3), padding="same", activation="relu"
    )
    conv_linear = layers.Conv2D(num_channels, kernel_size=(3, 3), padding="same")
    features = conv_linear(conv_relu(input_tensor))
    # Channel branch is built before the spatial branch, as in the original.
    branches = [
        channel_attention_block(features),
        spatial_attention_block(features),
    ]
    merged = layers.Concatenate(axis=-1)(branches)
    fused = layers.Conv2D(num_channels, kernel_size=(1, 1))(merged)
    return layers.Add()([input_tensor, fused])
|
enhance_me/mirnet/models/mirnet_model.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from tensorflow.keras import layers, Input, Model
|
| 3 |
+
|
| 4 |
+
from .recursive_residual_blocks import recursive_residual_group
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def mirnet_model(num_rrg, num_mrb, channels):
    """Assemble the MIRNet Keras model.

    Args:
        num_rrg: Number of ``recursive_residual_group`` stages to chain.
        num_mrb: Passed to each ``recursive_residual_group``.
        channels: Filter count of the first convolution, also passed to each
            ``recursive_residual_group``.

    Returns:
        A ``Model`` mapping an input of shape ``[None, None, 3]`` to the sum
        of that input and a 3-filter convolution of the processed features.
    """
    image = Input(shape=[None, None, 3])
    features = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(image)
    for _ in range(num_rrg):
        features = recursive_residual_group(features, num_mrb, channels)
    # Project back to 3 filters so the result can be added to the input.
    residual = layers.Conv2D(3, kernel_size=(3, 3), padding="same")(features)
    enhanced = layers.Add()([image, residual])
    return Model(image, enhanced)
|
enhance_me/mirnet/models/recursive_residual_blocks.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from tensorflow.keras import layers
|
| 3 |
+
|
| 4 |
+
from .skff import selective_kernel_feature_fusion
|
| 5 |
+
from .dual_attention import dual_attention_unit_block
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def down_sampling_module(input_tensor):
    """Max-pool the input and double its channel count.

    The main branch applies a 1x1 ReLU conv, a 3x3 ReLU conv, max pooling and
    a 1x1 conv with ``2 * C`` filters; the skip branch applies max pooling and
    a 1x1 conv with ``2 * C`` filters. The two branches are added.

    Args:
        input_tensor: Feature tensor whose last dimension ``C`` is static.

    Returns:
        The sum of the skip branch and the main branch.
    """
    num_channels = input_tensor.shape[-1]
    main_layers = [
        layers.Conv2D(num_channels, kernel_size=(1, 1), activation="relu"),
        layers.Conv2D(
            num_channels, kernel_size=(3, 3), padding="same", activation="relu"
        ),
        layers.MaxPooling2D(),
        layers.Conv2D(num_channels * 2, kernel_size=(1, 1)),
    ]
    main = input_tensor
    for layer in main_layers:
        main = layer(main)
    skip = layers.MaxPooling2D()(input_tensor)
    skip = layers.Conv2D(num_channels * 2, kernel_size=(1, 1))(skip)
    return layers.Add()([skip, main])
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def up_sampling_module(input_tensor):
    """Upsample the input and halve its channel count.

    The main branch applies a 1x1 ReLU conv, a 3x3 ReLU conv, ``UpSampling2D``
    and a 1x1 conv with ``C // 2`` filters; the skip branch applies
    ``UpSampling2D`` and a 1x1 conv with ``C // 2`` filters. The two branches
    are added.

    Args:
        input_tensor: Feature tensor whose last dimension ``C`` is static.

    Returns:
        The sum of the skip branch and the main branch.
    """
    num_channels = input_tensor.shape[-1]
    main_layers = [
        layers.Conv2D(num_channels, kernel_size=(1, 1), activation="relu"),
        layers.Conv2D(
            num_channels, kernel_size=(3, 3), padding="same", activation="relu"
        ),
        layers.UpSampling2D(),
        layers.Conv2D(num_channels // 2, kernel_size=(1, 1)),
    ]
    main = input_tensor
    for layer in main_layers:
        main = layer(main)
    skip = layers.UpSampling2D()(input_tensor)
    skip = layers.Conv2D(num_channels // 2, kernel_size=(1, 1))(skip)
    return layers.Add()([skip, main])
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# MRB Block
|
| 39 |
+
def multi_scale_residual_block(input_tensor, channels):
    """Process the input at three resolutions and fuse the results.

    The input is downsampled twice to obtain three feature levels. Each level
    goes through a dual attention unit, the levels are exchanged and fused
    with selective kernel feature fusion at every resolution, refined by a
    second dual attention unit, brought back to the first level's resolution,
    fused once more, convolved, and added to ``input_tensor``.

    Args:
        input_tensor: Feature tensor fed to the block.
        channels: Number of filters of the final 3x3 convolution.

    Returns:
        The sum of ``input_tensor`` and the fused multi-scale features.
    """
    # features
    level1 = input_tensor
    level2 = down_sampling_module(input_tensor)
    level3 = down_sampling_module(level2)
    # DAU
    level1_dau = dual_attention_unit_block(level1)
    level2_dau = dual_attention_unit_block(level2)
    level3_dau = dual_attention_unit_block(level3)
    # SKFF: every level receives the other two, resampled to its resolution.
    level1_skff = selective_kernel_feature_fusion(
        level1_dau,
        up_sampling_module(level2_dau),
        up_sampling_module(up_sampling_module(level3_dau)),
    )
    level2_skff = selective_kernel_feature_fusion(
        down_sampling_module(level1_dau), level2_dau, up_sampling_module(level3_dau)
    )
    level3_skff = selective_kernel_feature_fusion(
        down_sampling_module(down_sampling_module(level1_dau)),
        down_sampling_module(level2_dau),
        level3_dau,
    )
    # DAU 2: refine each level and upsample back to the first level.
    level1_dau_2 = dual_attention_unit_block(level1_skff)
    level2_dau_2 = up_sampling_module(dual_attention_unit_block(level2_skff))
    level3_dau_2 = up_sampling_module(
        up_sampling_module(dual_attention_unit_block(level3_skff))
    )
    # SKFF 2: fuse all three levels. The second argument used to be
    # level3_dau_2 again, which left level2_dau_2 unused.
    skff_ = selective_kernel_feature_fusion(level1_dau_2, level2_dau_2, level3_dau_2)
    conv = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(skff_)
    return layers.Add()([input_tensor, conv])
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def recursive_residual_group(input_tensor, num_mrb, channels):
    """Chain multi-scale residual blocks between two 3x3 convolutions.

    Args:
        input_tensor: Feature tensor fed to the group.
        num_mrb: Number of ``multi_scale_residual_block`` calls to chain.
        channels: Filter count of both convolutions, also passed to every
            ``multi_scale_residual_block``.

    Returns:
        The sum of the group's output and ``input_tensor``.
    """
    features = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(
        input_tensor
    )
    for _ in range(num_mrb):
        features = multi_scale_residual_block(features, channels)
    refined = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(features)
    return layers.Add()([refined, input_tensor])
|
enhance_me/mirnet/models/skff.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from tensorflow.keras import layers
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def selective_kernel_feature_fusion(
    multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3
):
    """Fuse three feature tensors using learned per-channel descriptors.

    The three inputs are summed, globally average-pooled, reshaped to
    ``(-1, 1, 1, C)`` and reduced by a 1x1 ReLU convolution with ``C // 8``
    filters. Three separate 1x1 softmax convolutions then produce one
    descriptor per input; each input is multiplied by its descriptor and the
    products are summed.

    Args:
        multi_scale_feature_1: First feature tensor; its last dimension ``C``
            sets the descriptor size.
        multi_scale_feature_2: Second feature tensor.
        multi_scale_feature_3: Third feature tensor.

    Returns:
        The sum of the three descriptor-weighted feature tensors.
    """
    branches = [multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3]
    num_channels = multi_scale_feature_1.shape[-1]
    fused = layers.Add()(branches)
    pooled = layers.GlobalAveragePooling2D()(fused)
    statistics = tf.reshape(pooled, shape=(-1, 1, 1, num_channels))
    compact = layers.Conv2D(
        filters=num_channels // 8, kernel_size=(1, 1), activation="relu"
    )(statistics)
    # NOTE(review): each softmax normalizes over the channel axis of its own
    # descriptor, not across the three branches -- confirm this is intended.
    descriptors = [
        layers.Conv2D(num_channels, kernel_size=(1, 1), activation="softmax")(
            compact
        )
        for _ in branches
    ]
    weighted = [
        branch * descriptor for branch, descriptor in zip(branches, descriptors)
    ]
    return layers.Add()(weighted)
|