"""Gradio demo tab offering two image sharpeners: Contrast Adaptive Sharpening (CAS) and a bilateral-filter-based Smart Sharpen."""
import cv2
import gradio as gr
import kornia
import numpy as np
import torch
import torch.nn.functional as F
def min_(items):
    """Return the elementwise minimum over an iterable of same-shaped tensors."""
    it = iter(items)
    acc = next(it)
    for tensor in it:
        acc = torch.minimum(acc, tensor)
    return acc
def max_(items):
    """Return the elementwise maximum over an iterable of same-shaped tensors."""
    it = iter(items)
    acc = next(it)
    for tensor in it:
        acc = torch.maximum(acc, tensor)
    return acc
def apply_cas(image, amount):
    """Apply Contrast Adaptive Sharpening (CAS) to an image.

    Args:
        image: HWC uint8 numpy array (assumes 3 channels, as Gradio delivers
            RGB by default — TODO confirm), or None when no image is loaded.
        amount: sharpening strength in [0, 1]; scales the per-pixel weight.

    Returns:
        Sharpened HWC uint8 numpy array, or None when `image` is None.
    """
    if image is None:
        return None
    # Normalize to float [0, 1] and rearrange HWC -> BCHW.
    x = torch.from_numpy(image).float().div_(255.0)
    x = x.permute(2, 0, 1).unsqueeze(0)
    eps = 1e-5
    padded = F.pad(x, pad=(1, 1, 1, 1))
    rows = padded.shape[-2] - 2
    cols = padded.shape[-1] - 2

    def window(dy, dx):
        # View of the padded tensor shifted by (dy, dx); dy/dx in {0, 1, 2}
        # select one cell of the 3x3 neighbourhood around each pixel.
        return padded[..., dy:dy + rows, dx:dx + cols]

    nw, north, ne = window(0, 0), window(0, 1), window(0, 2)
    west, center, east = window(1, 0), window(1, 1), window(1, 2)
    sw, south, se = window(2, 0), window(2, 1), window(2, 2)

    # Local min/max over the plus-shaped and diagonal neighbourhoods.
    cross = (north, west, center, east, south)
    lo = min_(cross) + min_((nw, ne, sw, se))
    hi = max_(cross) + max_((nw, ne, sw, se))

    # Per-pixel sharpening amplitude, then the negative-lobe kernel weight.
    amp = torch.sqrt(torch.reciprocal(hi + eps) * torch.minimum(lo, (2 - hi)))
    weight = -amp * (amount * (1 / 5 - 1 / 8) + 1 / 8)
    norm = torch.reciprocal(1 + 4 * weight)
    sharpened = ((north + west + east + south) * weight + center) * norm
    sharpened = sharpened.clamp(0, 1)

    # Back to HWC uint8.
    sharpened = sharpened.squeeze(0).permute(1, 2, 0)
    return (sharpened.numpy() * 255).astype(np.uint8)
def apply_smart_sharpen(image, noise_radius, preserve_edges, sharpen, ratio):
    """Denoise with a bilateral filter, sharpen, and blend the two results.

    Args:
        image: HWC uint8 numpy array from Gradio, or None when no image is set.
        noise_radius: bilateral filter diameter in pixels; values <= 1 skip
            the denoise pass.
        preserve_edges: 0..1 slider; higher preserves more edges (it is
            inverted into the filter's sigmaColor).
        sharpen: kornia sharpness factor; 0 skips the sharpening pass.
        ratio: blend weight of the sharpened image vs. the denoised one.

    Returns:
        HWC uint8 numpy array, or None when `image` is None.
    """
    if image is None:
        return None
    # Normalize to float [0, 1].
    image = torch.from_numpy(image).float() / 255.0
    if preserve_edges > 0:
        # Invert so a high "preserve edges" setting means a small sigmaColor;
        # floor at 0.05 to keep the filter well-behaved.
        preserve_edges = max(1 - preserve_edges, 0.05)
    # Gradio sliders may deliver floats even with step=1, but
    # cv2.bilateralFilter requires an integer diameter.
    noise_radius = int(noise_radius)
    # Noise reduction: edge-preserving bilateral filter.
    if noise_radius > 1:
        # Radius -> sigma mapping matching OpenCV's Gaussian-kernel heuristic.
        sigma = 0.3 * ((noise_radius - 1) * 0.5 - 1) + 0.8
        blurred = cv2.bilateralFilter(image.numpy(), noise_radius, preserve_edges, sigma)
        blurred = torch.from_numpy(blurred)
    else:
        blurred = image
    # Sharpening is applied to the *original* image, not the denoised one;
    # the blend below decides how much of each survives.
    if sharpen > 0:
        img_chw = image.permute(2, 0, 1).unsqueeze(0)  # HWC -> BCHW
        sharpened = kornia.enhance.sharpness(img_chw, sharpen).squeeze(0).permute(1, 2, 0)
    else:
        sharpened = image
    # Blend, clamp, and convert back to uint8.
    result = torch.clamp(ratio * sharpened + (1 - ratio) * blurred, 0, 1)
    return (result.numpy() * 255).astype(np.uint8)
def create_sharpen_tab():
    """Build the "Sharpening" UI tab.

    Lays out an input image, CAS and Smart Sharpen control sub-tabs, and a
    shared output image, then wires each apply button to its filter function.
    """
    def _slider(lo, hi, init, step, label):
        # One-line factory so each slider below reads as a single declaration.
        return gr.Slider(minimum=lo, maximum=hi, value=init, step=step, label=label)

    with gr.Tab("Sharpening"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", height=256)
                with gr.Tabs():
                    with gr.Tab("CAS"):
                        cas_amount = _slider(0.0, 1.0, 0.8, 0.05, "Amount")
                        cas_btn = gr.Button("Apply CAS")
                    with gr.Tab("Smart Sharpen"):
                        noise_radius = _slider(1, 25, 7, 1, "Noise Radius")
                        preserve_edges = _slider(0.0, 1.0, 0.75, 0.05, "Preserve Edges")
                        sharpen = _slider(0.0, 25.0, 5.0, 0.5, "Sharpen Amount")
                        ratio = _slider(0.0, 1.0, 0.5, 0.1, "Blend Ratio")
                        smart_btn = gr.Button("Apply Smart Sharpen")
            with gr.Column():
                output_image = gr.Image(label="Sharpened Image")
        # Both buttons write to the same output image component.
        cas_btn.click(
            fn=apply_cas,
            inputs=[input_image, cas_amount],
            outputs=output_image,
        )
        smart_btn.click(
            fn=apply_smart_sharpen,
            inputs=[input_image, noise_radius, preserve_edges, sharpen, ratio],
            outputs=output_image,
        )