Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
import torch.nn.functional as F | |
import kornia | |
import numpy as np | |
def compute_mean_std(tensor, mask=None):
    """Return the per-channel spatial mean and standard deviation of ``tensor``.

    Reductions run over dims 2 and 3 with ``keepdim=True``, so for a BCHW
    tensor both results have shape ``(B, C, 1, 1)`` and broadcast back
    against ``tensor``.

    Both branches use the population (divide-by-N) standard deviation. As a
    result, an all-ones mask gives the same statistics as ``mask=None``, and
    a single-pixel input yields a std of 0 instead of NaN.

    Args:
        tensor: Float tensor whose dims 2 and 3 are the spatial axes.
        mask: Optional tensor broadcastable to ``tensor``, used as per-position
            weights (presumably 0/1 — confirm at call sites). ``None`` uses
            every position with equal weight.

    Returns:
        ``(mean, std)`` tuple of tensors with dims 2 and 3 kept as size 1.
    """
    if mask is not None:
        masked_tensor = tensor * mask
        # Clamp so an all-zero mask does not divide by zero.
        mask_sum = torch.clamp(mask.sum(dim=[2, 3], keepdim=True), min=1e-6)
        mean = torch.nan_to_num(masked_tensor.sum(dim=[2, 3], keepdim=True) / mask_sum)
        # Multiplying by the mask zeroes the (0 - mean)^2 terms that the
        # masked-out positions of masked_tensor would otherwise contribute.
        sq_dev = ((masked_tensor - mean) ** 2 * mask).sum(dim=[2, 3], keepdim=True)
        std = torch.sqrt(torch.nan_to_num(sq_dev / mask_sum))
    else:
        mean = tensor.mean(dim=[2, 3], keepdim=True)
        # Population variance (divide by N), matching the masked branch above.
        # tensor.std() would apply Bessel's correction and return NaN for a
        # single pixel.
        std = torch.sqrt(((tensor - mean) ** 2).mean(dim=[2, 3], keepdim=True))
    return mean, std
def apply_color_match(image, reference, color_space, factor, device='cpu'):
    """Transfer the per-channel mean/std of ``reference`` onto ``image``.

    Both images are converted to ``color_space``. Each channel of ``image``
    is standardised with its own mean/std, then rescaled with the reference's
    mean/std. The result is blended with the unmodified image by ``factor``
    and converted back to RGB.

    Args:
        image: HWC numpy array with values in 0-255 to recolour, or ``None``.
            Presumably 3-channel RGB for the kornia conversions — confirm.
        reference: HWC numpy array with values in 0-255 supplying the target
            statistics, or ``None``.
        color_space: ``"LAB"``, ``"YCbCr"``, ``"LUV"``, ``"YUV"`` or ``"XYZ"``.
            Any other value is treated as RGB (no conversion).
        factor: Blend weight; 1.0 is the full match, 0.0 the original image.
        device: Torch device on which the computation runs.

    Returns:
        HWC ``np.uint8`` array, or ``None`` when either input is ``None``.
    """
    if image is None or reference is None:
        return None

    def to_bchw(array):
        # HWC numpy (0-255) -> 1xCxHxW float tensor in [0, 1] on `device`.
        tensor = torch.from_numpy(array).float() / 255.0
        return tensor.unsqueeze(0).permute(0, 3, 1, 2).to(device)

    image = to_bchw(image)
    reference = to_bchw(reference)

    # (forward conversion, inverse conversion) per supported colour space.
    conversions = {
        "LAB": (kornia.color.rgb_to_lab, kornia.color.lab_to_rgb),
        "YCbCr": (kornia.color.rgb_to_ycbcr, kornia.color.ycbcr_to_rgb),
        "LUV": (kornia.color.rgb_to_luv, kornia.color.luv_to_rgb),
        "YUV": (kornia.color.rgb_to_yuv, kornia.color.yuv_to_rgb),
        "XYZ": (kornia.color.rgb_to_xyz, kornia.color.xyz_to_rgb),
    }
    to_space, back_conversion = conversions.get(
        color_space, (lambda x: x, lambda x: x)  # RGB / unknown: identity
    )
    image_conv = to_space(image)
    reference_conv = to_space(reference)

    reference_mean, reference_std = compute_mean_std(reference_conv)
    image_mean, image_std = compute_mean_std(image_conv)

    # A flat channel has std 0. Dividing by it gives inf, which nan_to_num
    # would map to the float32 maximum and saturate the output, so clamp the
    # denominator first.
    normalized = torch.nan_to_num((image_conv - image_mean) / image_std.clamp(min=1e-6))
    matched = normalized * reference_std + reference_mean
    matched = factor * matched + (1 - factor) * image_conv

    matched = back_conversion(matched)

    # Back to HWC uint8. .cpu() is required before .numpy() for non-CPU devices.
    output = matched.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    return (output * 255).astype(np.uint8)
def analyze_color_statistics(image):
    """Summarise the LAB colour distribution of an RGB image tensor.

    ``image`` is converted with ``kornia.color.rgb_to_lab`` and split along
    dim 1 into its L, a and b channels. This assumes a 3-channel BCHW tensor
    (TODO confirm at call sites). Every reduction runs over all elements of a
    channel, so a batch larger than one is pooled into a single statistic.

    Returns:
        Tuple of 0-d tensors ``(mean_l, std_l, mean_a, mean_b, std_ab)``, where
        ``std_ab = sqrt(var(a) + var(b))``. ``std``/``var`` use torch's default
        (Bessel-corrected) estimator.
    """
    lab = kornia.color.rgb_to_lab(image)
    lightness, chroma_a, chroma_b = lab.chunk(3, dim=1)
    # Joint spread of the two chroma channels.
    combined_chroma_spread = (chroma_a.var() + chroma_b.var()).sqrt()
    return (
        lightness.mean(),
        lightness.std(),
        chroma_a.mean(),
        chroma_b.mean(),
        combined_chroma_spread,
    )
def apply_adobe_color_match(image, reference, color_space, luminance_factor,
                            color_intensity_factor, fade_factor, neutralization_factor):
    """"Adobe style" colour match performed in LAB space.

    Steps:

    - **L channel:** re-centred on the reference mean. Its contrast is scaled
      by ``ref_std_L / img_std_L * luminance_factor``.
    - **a/b channels:** each first has ``neutralization_factor`` times the
      image's own mean subtracted, removing a colour cast. Both are then
      scaled by ``ref_std_ab / img_std_ab * color_intensity_factor``. The
      reference's mean a/b is not applied.
    - **Fade:** the RGB result is blended with the original image by
      ``fade_factor``.

    Args:
        image: HWC numpy array with values in 0-255 to recolour, or ``None``.
        reference: HWC numpy array with values in 0-255 supplying the target
            statistics, or ``None``.
        color_space: Currently ignored; the transfer always runs in LAB. It
            is kept so the signature matches the UI wiring.
        luminance_factor: Extra multiplier on the L contrast gain.
        color_intensity_factor: Extra multiplier on the a/b gain.
        fade_factor: 1.0 gives the full result, 0.0 the original image.
        neutralization_factor: Fraction of the image's mean a/b to subtract.

    Returns:
        HWC ``np.uint8`` array, or ``None`` when either input is ``None``.
    """
    if image is None or reference is None:
        return None

    # Smallest std allowed as a divisor in the gain computations below.
    eps = 1e-6

    # HWC numpy (0-255) -> 1xCxHxW float tensors in [0, 1].
    image = (torch.from_numpy(image).float() / 255.0).unsqueeze(0).permute(0, 3, 1, 2)
    reference = (torch.from_numpy(reference).float() / 255.0).unsqueeze(0).permute(0, 3, 1, 2)

    src_mean_l, src_std_l, _, _, src_std_ab = analyze_color_statistics(reference)
    dest_mean_l, dest_std_l, dest_mean_a, dest_mean_b, dest_std_ab = analyze_color_statistics(image)

    l, a, b = kornia.color.rgb_to_lab(image).chunk(3, dim=1)

    # A flat-luminance image, or a neutral/grey one (chroma std 0), would
    # otherwise divide by zero and produce inf/NaN that reaches the uint8
    # cast. Stds that are tiny but above eps still yield very large gains.
    luminance_gain = src_std_l / dest_std_l.clamp(min=eps) * luminance_factor
    chroma_gain = src_std_ab / dest_std_ab.clamp(min=eps) * color_intensity_factor

    l_new = (l - dest_mean_l) * luminance_gain + src_mean_l
    # Neutralise the image's own colour cast, then rescale chroma intensity.
    a_new = (a - neutralization_factor * dest_mean_a) * chroma_gain
    b_new = (b - neutralization_factor * dest_mean_b) * chroma_gain

    rgb_new = kornia.color.lab_to_rgb(torch.cat([l_new, a_new, b_new], dim=1))
    result = fade_factor * rgb_new + (1 - fade_factor) * image

    # Back to HWC uint8.
    output = result.squeeze(0).permute(1, 2, 0)
    return (output.clamp(0, 1).numpy() * 255).astype(np.uint8)
def create_color_match_tab():
    """Build the "Color Matching" tab and wire its buttons to the matchers.

    The tab contains:

    - **Left column:** input and reference image pickers.
    - **"Standard" sub-tab:** drives ``apply_color_match``.
    - **"Adobe Style" sub-tab:** drives ``apply_adobe_color_match``.
    - **Right column:** a shared output image that both buttons write into.

    Presumably this must be called inside an active ``gr.Blocks`` context
    (TODO confirm against the caller). Returns ``None``.
    """
    def slider(label, maximum, value):
        # Every slider in this tab starts at 0.0 and moves in 0.05 steps.
        return gr.Slider(
            minimum=0.0, maximum=maximum, value=value, step=0.05, label=label
        )

    with gr.Tab("Color Matching"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", height=256)
                reference_image = gr.Image(label="Reference Image", height=256)
                with gr.Tabs():
                    with gr.Tab("Standard"):
                        color_space = gr.Dropdown(
                            choices=["LAB", "YCbCr", "RGB", "LUV", "YUV", "XYZ"],
                            value="LAB",
                            label="Color Space",
                        )
                        factor = slider("Factor", 1.0, 1.0)
                        standard_btn = gr.Button("Apply Standard Color Match")
                    with gr.Tab("Adobe Style"):
                        # NOTE(review): apply_adobe_color_match never reads
                        # its color_space argument, so this dropdown
                        # currently has no effect on the output.
                        adobe_color_space = gr.Dropdown(
                            choices=["RGB", "LAB"],
                            value="LAB",
                            label="Color Space",
                        )
                        luminance_factor = slider("Luminance Factor", 2.0, 1.0)
                        color_intensity_factor = slider("Color Intensity Factor", 2.0, 1.0)
                        fade_factor = slider("Fade Factor", 1.0, 1.0)
                        neutralization_factor = slider("Neutralization Factor", 1.0, 0.0)
                        adobe_btn = gr.Button("Apply Adobe Style Color Match")
            with gr.Column():
                output_image = gr.Image(label="Color Matched Image")

        standard_inputs = [input_image, reference_image, color_space, factor]
        standard_btn.click(
            fn=apply_color_match,
            inputs=standard_inputs,
            outputs=output_image,
        )

        adobe_inputs = [
            input_image,
            reference_image,
            adobe_color_space,
            luminance_factor,
            color_intensity_factor,
            fade_factor,
            neutralization_factor,
        ]
        adobe_btn.click(
            fn=apply_adobe_color_match,
            inputs=adobe_inputs,
            outputs=output_image,
        )