# TRELLIS-Single3D / color_match_processor.py
# Last commit: 215c5da "add tab" by gokaygokay (file size 8.13 kB)
import gradio as gr
import torch
import torch.nn.functional as F
import kornia
import numpy as np
def compute_mean_std(tensor, mask=None):
    """Return the per-channel spatial mean and standard deviation of ``tensor``.

    Statistics are reduced over dims 2 and 3 (H and W for the BCHW tensors the
    callers in this module pass) and returned with ``keepdim=True`` so they
    broadcast back against ``tensor``.

    Args:
        tensor: Tensor whose dims 2 and 3 are reduced (BCHW in this module).
        mask: Optional weight tensor broadcastable to ``tensor`` (e.g. B1HW).
            When given, mean and std are mask-weighted and the std is the
            population (biased) estimate. Without a mask, ``torch.std``'s
            default (unbiased) estimate is used.

    Returns:
        Tuple ``(mean, std)`` with dims 2 and 3 kept as size 1.
    """
    if mask is None:
        return tensor.mean(dim=[2, 3], keepdim=True), tensor.std(dim=[2, 3], keepdim=True)

    # Guard against an all-zero mask so the divisions below stay finite.
    mask_sum = torch.clamp(mask.sum(dim=[2, 3], keepdim=True), min=1e-6)
    mean = torch.nan_to_num((tensor * mask).sum(dim=[2, 3], keepdim=True) / mask_sum)
    # Weight the squared deviations of the raw values by the mask. The earlier
    # form ((tensor * mask) - mean) ** 2 * mask is equivalent only for binary
    # masks; fractional weights skewed the variance.
    variance = ((tensor - mean) ** 2 * mask).sum(dim=[2, 3], keepdim=True) / mask_sum
    std = torch.sqrt(torch.nan_to_num(variance))
    return mean, std
def apply_color_match(image, reference, color_space, factor, device='cpu'):
    """Transfer per-channel colour statistics from ``reference`` onto ``image``.

    Both images are converted to ``color_space``. Each channel of ``image`` is
    standardised (zero mean, unit std) and then rescaled to the reference
    channel's mean and std. The result is blended with the unmodified image by
    ``factor`` and converted back to RGB.

    Args:
        image: HxW or HxWxC array to recolour, or None. Values are divided by
            255, so presumably uint8 in [0, 255]. A single channel is replicated
            to RGB and any channel beyond the third (alpha) is dropped.
        reference: Array of the same kind supplying the target statistics, or None.
        color_space: "LAB", "YCbCr", "LUV", "YUV" or "XYZ". Any other value
            (including "RGB") matches the statistics directly in RGB.
        factor: Blend weight. 1.0 is the full match; 0.0 keeps the input
            (after the colour-space round trip).
        device: Torch device the computation runs on.

    Returns:
        HxWx3 uint8 array, or None if either input is missing.
    """
    if image is None or reference is None:
        return None

    def to_bchw(array):
        # HxW[xC] array -> float (1, 3, H, W) tensor in [0, 1] on ``device``.
        tensor = torch.from_numpy(np.ascontiguousarray(array)).float().to(device) / 255.0
        if tensor.ndim == 2:
            tensor = tensor.unsqueeze(-1)
        if tensor.shape[-1] == 1:
            # Grayscale: replicate so the RGB colour-space conversions accept it.
            tensor = tensor.expand(-1, -1, 3)
        # Drop an alpha channel if present; the conversions expect exactly RGB.
        return tensor[..., :3].unsqueeze(0).permute(0, 3, 1, 2)

    def identity(x):
        return x

    image = to_bchw(image)
    reference = to_bchw(reference)

    if color_space == "LAB":
        forward, backward = kornia.color.rgb_to_lab, kornia.color.lab_to_rgb
    elif color_space == "YCbCr":
        forward, backward = kornia.color.rgb_to_ycbcr, kornia.color.ycbcr_to_rgb
    elif color_space == "LUV":
        forward, backward = kornia.color.rgb_to_luv, kornia.color.luv_to_rgb
    elif color_space == "YUV":
        forward, backward = kornia.color.rgb_to_yuv, kornia.color.yuv_to_rgb
    elif color_space == "XYZ":
        forward, backward = kornia.color.rgb_to_xyz, kornia.color.xyz_to_rgb
    else:  # "RGB" and any unrecognised value: match directly in RGB.
        forward = backward = identity

    image_conv = forward(image)
    reference_conv = forward(reference)

    reference_mean, reference_std = compute_mean_std(reference_conv)
    image_mean, image_std = compute_mean_std(image_conv)

    # A flat channel (std ~ 0, or NaN for a single pixel) has nothing to
    # standardise. Dividing by it turned rounding noise in (x - mean) into
    # +/-inf, so divide by 1 there instead. A reference without spread
    # contributes only its mean.
    eps = 1e-6
    safe_std = torch.where(image_std > eps, image_std, torch.ones_like(image_std))
    reference_std = torch.nan_to_num(reference_std)

    matched = torch.nan_to_num((image_conv - image_mean) / safe_std) * reference_std + reference_mean
    matched = factor * matched + (1 - factor) * image_conv

    output = backward(matched).squeeze(0).permute(1, 2, 0)
    # Round rather than truncate: values can land just below an integer after
    # the /255 * 255 round trip and would otherwise drop one level.
    return (output.clamp(0, 1) * 255).round().cpu().numpy().astype(np.uint8)
def analyze_color_statistics(image):
    """Summarise an RGB image tensor's colour distribution in CIE LAB.

    Args:
        image: RGB tensor accepted by ``kornia.color.rgb_to_lab`` with channels
            on dim 1 (callers in this module pass (1, 3, H, W) tensors in [0, 1]).

    Returns:
        Tuple of 0-dim tensors ``(mean_l, std_l, mean_a, mean_b, std_ab)``.
        Means and stds are taken over all elements of each channel, and
        ``std_ab`` is the combined chroma spread ``sqrt(var(a) + var(b))``.
    """
    lightness, chroma_a, chroma_b = kornia.color.rgb_to_lab(image).unbind(dim=1)
    chroma_spread = (chroma_a.var() + chroma_b.var()).sqrt()
    return (
        lightness.mean(),
        lightness.std(),
        chroma_a.mean(),
        chroma_b.mean(),
        chroma_spread,
    )
def apply_adobe_color_match(image, reference, color_space, luminance_factor, color_intensity_factor, fade_factor, neutralization_factor):
    """"Adobe style" transfer of lightness and chroma statistics in CIE LAB.

    The image's L channel is standardised and rescaled to the reference's L
    mean and std. Its a/b channels are optionally shifted towards neutral and
    scaled by the ratio of the combined a/b spreads. The result is converted
    back to RGB and blended with the original image.

    Args:
        image: HxW or HxWxC array to recolour, or None. Values are divided by
            255, so presumably uint8 in [0, 255]. A single channel is replicated
            to RGB and any channel beyond the third (alpha) is dropped.
        reference: Array of the same kind supplying the target look, or None.
        color_space: Accepted for interface compatibility but not used; the
            transfer always runs in LAB.
        luminance_factor: Multiplier on the lightness deviation from the mean
            (a contrast control) after matching.
        color_intensity_factor: Multiplier on the a/b channels after matching.
        fade_factor: Blend weight; 1.0 is the matched image, 0.0 the original.
        neutralization_factor: Fraction of the image's mean a/b cast subtracted
            before scaling (0 keeps the cast, 1 removes it).

    Returns:
        HxWx3 uint8 array, or None if either input is missing.
    """
    # NOTE(review): ``color_space`` is ignored, so the UI's "RGB" choice has no
    # effect. It stays in the signature because the Gradio handler passes it.
    if image is None or reference is None:
        return None

    def to_bchw(array):
        # HxW[xC] array -> float (1, 3, H, W) tensor in [0, 1].
        tensor = torch.from_numpy(np.ascontiguousarray(array)).float() / 255.0
        if tensor.ndim == 2:
            tensor = tensor.unsqueeze(-1)
        if tensor.shape[-1] == 1:
            tensor = tensor.expand(-1, -1, 3)
        return tensor[..., :3].unsqueeze(0).permute(0, 3, 1, 2)

    def std_ratio(target_std, current_std, eps=1e-6):
        # A flat channel (std ~ 0, or NaN for a single pixel) has no spread to
        # rescale. Dividing by it produced inf/NaN output, so leave it unscaled.
        return torch.where(current_std > eps, target_std / current_std, torch.ones_like(current_std))

    image = to_bchw(image)
    reference = to_bchw(reference)

    # Only the reference's L statistics and a/b spread are used as targets.
    src_mean_l, src_std_l, _, _, src_std_ab = analyze_color_statistics(reference)
    dest_mean_l, dest_std_l, dest_mean_a, dest_mean_b, dest_std_ab = analyze_color_statistics(image)

    lightness, chroma_a, chroma_b = kornia.color.rgb_to_lab(image).chunk(3, dim=1)

    lightness = (lightness - dest_mean_l) * std_ratio(src_std_l, dest_std_l) * luminance_factor + src_mean_l

    # Remove (part of) the image's average colour cast.
    chroma_a = chroma_a - neutralization_factor * dest_mean_a
    chroma_b = chroma_b - neutralization_factor * dest_mean_b

    chroma_scale = std_ratio(src_std_ab, dest_std_ab) * color_intensity_factor
    lab = torch.cat([lightness, chroma_a * chroma_scale, chroma_b * chroma_scale], dim=1)

    rgb = kornia.color.lab_to_rgb(lab)
    result = fade_factor * rgb + (1 - fade_factor) * image

    output = result.squeeze(0).permute(1, 2, 0)
    # Round rather than truncate so values just below an integer keep their level.
    return (output.clamp(0, 1) * 255).round().numpy().astype(np.uint8)
def create_color_match_tab():
    """Build the "Color Matching" tab and wire its buttons to the matchers.

    Presumably called while a ``gr.Blocks`` context is open. Lays out the input
    and reference images, a "Standard" sub-tab driving ``apply_color_match``
    and an "Adobe Style" sub-tab driving ``apply_adobe_color_match``. Both
    write into a shared output image.
    """

    def make_slider(label, value, maximum=1.0):
        # Every slider on this tab starts at 0.0 and moves in 0.05 steps.
        return gr.Slider(minimum=0.0, maximum=maximum, value=value, step=0.05, label=label)

    with gr.Tab("Color Matching"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", height=256)
                reference_image = gr.Image(label="Reference Image", height=256)
                with gr.Tabs():
                    with gr.Tab("Standard"):
                        color_space = gr.Dropdown(
                            choices=["LAB", "YCbCr", "RGB", "LUV", "YUV", "XYZ"],
                            value="LAB",
                            label="Color Space",
                        )
                        factor = make_slider("Factor", 1.0)
                        standard_btn = gr.Button("Apply Standard Color Match")
                    with gr.Tab("Adobe Style"):
                        adobe_color_space = gr.Dropdown(
                            choices=["RGB", "LAB"],
                            value="LAB",
                            label="Color Space",
                        )
                        luminance_factor = make_slider("Luminance Factor", 1.0, maximum=2.0)
                        color_intensity_factor = make_slider("Color Intensity Factor", 1.0, maximum=2.0)
                        fade_factor = make_slider("Fade Factor", 1.0)
                        neutralization_factor = make_slider("Neutralization Factor", 0.0)
                        adobe_btn = gr.Button("Apply Adobe Style Color Match")
            with gr.Column():
                output_image = gr.Image(label="Color Matched Image")

        standard_inputs = [input_image, reference_image, color_space, factor]
        adobe_inputs = [
            input_image,
            reference_image,
            adobe_color_space,
            luminance_factor,
            color_intensity_factor,
            fade_factor,
            neutralization_factor,
        ]
        standard_btn.click(fn=apply_color_match, inputs=standard_inputs, outputs=output_image)
        adobe_btn.click(fn=apply_adobe_color_match, inputs=adobe_inputs, outputs=output_image)