# TRELLIS-Single3D / sharpen_processor.py
# Author: gokaygokay — commit c63008c ("add tab"), 5.01 kB
import gradio as gr
import torch
import torch.nn.functional as F
import cv2
import kornia
import numpy as np
def min_(items):
    """Return the element-wise minimum of a non-empty sequence of tensors.

    The tensors are folded left to right with ``torch.minimum``; an empty
    sequence raises ``IndexError``.
    """
    smallest, rest = items[0], items[1:]
    for candidate in rest:
        smallest = torch.minimum(smallest, candidate)
    return smallest
def max_(items):
    """Return the element-wise maximum of a non-empty sequence of tensors.

    The tensors are folded left to right with ``torch.maximum``; an empty
    sequence raises ``IndexError``.
    """
    largest, rest = items[0], items[1:]
    for candidate in rest:
        largest = torch.maximum(largest, candidate)
    return largest
def apply_cas(image, amount):
    """Sharpen ``image`` with Contrast Adaptive Sharpening (CAS).

    Each output pixel is a normalized mix of the centre pixel and its four
    edge neighbours weighted by a negative ``w``. ``w`` scales with a
    per-pixel amplitude ``sqrt(min(mn, 2 - mx) / mx)``, where ``mn``/``mx``
    are the summed cross and diagonal minima/maxima of the 3x3 window, so
    sharpening is reduced where the neighbourhood already spans a wide range.

    Args:
        image: NumPy image of shape (H, W, C) or single-channel (H, W), or
            None. Values are scaled by 1/255, so uint8 input is assumed
            (presumably the default ``gr.Image`` output — confirm).
        amount: sharpening strength; the UI slider restricts it to [0, 1].

    Returns:
        uint8 NumPy array with the same shape as ``image``, or None when no
        image was supplied.
    """
    if image is None:
        return None

    # Treat a single-channel (H, W) array as (H, W, 1) so one code path
    # handles both; the channel axis is dropped again before returning.
    grayscale = image.ndim == 2
    if grayscale:
        image = image[..., None]

    # (H, W, C) in [0, 255] -> (1, C, H, W) float in [0, 1].
    img = torch.from_numpy(np.ascontiguousarray(image)).float() / 255.0
    img = img.unsqueeze(0).permute(0, 3, 1, 2)

    epsilon = 1e-5
    # Replicate edge pixels instead of zero-padding: a zero in the window
    # makes the local minimum 0, which forces amp to 0 and would leave the
    # whole 1-pixel border unsharpened.
    padded = F.pad(img, pad=(1, 1, 1, 1), mode="replicate")

    # The 3x3 neighbourhood as nine shifted views of the padded image:
    #   a b c
    #   d e f
    #   g h i
    a = padded[..., :-2, :-2]
    b = padded[..., :-2, 1:-1]
    c = padded[..., :-2, 2:]
    d = padded[..., 1:-1, :-2]
    e = padded[..., 1:-1, 1:-1]
    f = padded[..., 1:-1, 2:]
    g = padded[..., 2:, :-2]
    h = padded[..., 2:, 1:-1]
    i = padded[..., 2:, 2:]

    cross = (b, d, e, f, h)
    diag = (a, c, g, i)
    # Sums of cross and diagonal extrema; each lies in [0, 2].
    mn = min_(cross) + min_(diag)
    mx = max_(cross) + max_(diag)

    amp = torch.sqrt(torch.reciprocal(mx + epsilon) * torch.minimum(mn, 2 - mx))
    # amount in [0, 1] moves the peak neighbour weight from -1/8 to -1/5.
    w = -amp * (amount * (1 / 5 - 1 / 8) + 1 / 8)
    div = torch.reciprocal(1 + 4 * w)
    output = ((b + d + f + h) * w + e) * div
    output = output.clamp(0, 1)

    # (1, C, H, W) -> (H, W, C). Round instead of truncating so float
    # round-off such as 254.9999 maps back to 255 rather than 254.
    output = output.squeeze(0).permute(1, 2, 0).numpy()
    output = np.rint(output * 255.0).astype(np.uint8)
    if grayscale:
        output = output[..., 0]
    return output
def apply_smart_sharpen(image, noise_radius, preserve_edges, sharpen, ratio):
    """Blend a noise-reduced copy of ``image`` with a sharpened copy.

    Both copies are derived from the original image: the noise-reduced one
    via ``cv2.bilateralFilter``, the sharpened one via
    ``kornia.enhance.sharpness``. They are mixed as
    ``ratio * sharpened + (1 - ratio) * blurred`` and clamped to [0, 1].

    Args:
        image: NumPy image of shape (H, W, C) or single-channel (H, W), or
            None. Values are scaled by 1/255, so uint8 input is assumed.
        noise_radius: bilateral filter neighbourhood diameter ``d``; it is
            truncated to int, and values <= 1 skip the filter.
        preserve_edges: slider value in [0, 1]; larger values give the
            bilateral filter a smaller colour sigma (floored at 0.05), i.e.
            edges are smoothed less. 0 is passed through as-is.
        sharpen: kornia sharpness factor; values <= 0 skip sharpening.
        ratio: weight of the sharpened copy in the final blend.

    Returns:
        uint8 NumPy array with the same shape as ``image``, or None when no
        image was supplied.
    """
    if image is None:
        return None

    # Treat a single-channel (H, W) array as (H, W, 1) so one code path
    # handles both; the channel axis is dropped again before returning.
    grayscale = image.ndim == 2
    if grayscale:
        image = image[..., None]

    img = torch.from_numpy(np.ascontiguousarray(image)).float() / 255.0

    # Invert the slider: higher "preserve edges" -> smaller sigmaColor.
    # NOTE(review): OpenCV substitutes 1 for a non-positive sigmaColor, so
    # preserve_edges == 0 behaves like sigmaColor = 1 — confirm.
    if preserve_edges > 0:
        sigma_color = max(1 - preserve_edges, 0.05)
    else:
        sigma_color = preserve_edges

    # cv2.bilateralFilter requires an integer diameter; slider values may
    # arrive as floats.
    diameter = int(noise_radius)
    if diameter > 1:
        # Kernel-size -> sigma heuristic (appears to match the default used
        # by OpenCV's getGaussianKernel).
        sigma_space = 0.3 * ((diameter - 1) * 0.5 - 1) + 0.8
        filtered = cv2.bilateralFilter(img.numpy(), diameter, sigma_color, sigma_space)
        # OpenCV returns single-channel results as 2-D arrays; restore the
        # (H, W, C) shape so the blend below does not mis-broadcast.
        blurred = torch.from_numpy(filtered).reshape(img.shape)
    else:
        blurred = img

    if sharpen > 0:
        batch = img.permute(2, 0, 1).unsqueeze(0)  # (1, C, H, W)
        sharpened = kornia.enhance.sharpness(batch, sharpen).squeeze(0).permute(1, 2, 0)
    else:
        sharpened = img

    result = torch.clamp(ratio * sharpened + (1 - ratio) * blurred, 0, 1)

    # Round instead of truncating so values like 254.9999 map back to 255.
    output = np.rint(result.numpy() * 255.0).astype(np.uint8)
    if grayscale:
        output = output[..., 0]
    return output
def _build_cas_controls():
    """Create the "CAS" sub-tab; return its (amount slider, apply button)."""
    with gr.Tab("CAS"):
        amount = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.8,
            step=0.05,
            label="Amount",
        )
        run = gr.Button("Apply CAS")
    return amount, run


def _build_smart_sharpen_controls():
    """Create the "Smart Sharpen" sub-tab; return (sliders tuple, apply button).

    The sliders are returned in the positional order expected by
    ``apply_smart_sharpen``: noise radius, preserve edges, sharpen amount,
    blend ratio.
    """
    with gr.Tab("Smart Sharpen"):
        radius = gr.Slider(
            minimum=1,
            maximum=25,
            value=7,
            step=1,
            label="Noise Radius",
        )
        edges = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.75,
            step=0.05,
            label="Preserve Edges",
        )
        strength = gr.Slider(
            minimum=0.0,
            maximum=25.0,
            value=5.0,
            step=0.5,
            label="Sharpen Amount",
        )
        blend = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.1,
            label="Blend Ratio",
        )
        run = gr.Button("Apply Smart Sharpen")
    return (radius, edges, strength, blend), run


def create_sharpen_tab():
    """Build the "Sharpening" tab and wire its buttons to the processors.

    Layout: one row with two columns. The left column holds the input image
    and two sub-tabs ("CAS", "Smart Sharpen") with their controls; the right
    column holds the output image, which both buttons write to.

    NOTE(review): Gradio components and listeners attach to the current
    layout context, so this is presumably called inside the app's
    ``gr.Blocks`` — confirm against the caller.
    """
    with gr.Tab("Sharpening"):
        with gr.Row():
            with gr.Column():
                source = gr.Image(label="Input Image", height=256)
                with gr.Tabs():
                    cas_strength, cas_run = _build_cas_controls()
                    smart_inputs, smart_run = _build_smart_sharpen_controls()
            with gr.Column():
                result = gr.Image(label="Sharpened Image")

        cas_run.click(
            fn=apply_cas,
            inputs=[source, cas_strength],
            outputs=result,
        )
        smart_run.click(
            fn=apply_smart_sharpen,
            inputs=[source, *smart_inputs],
            outputs=result,
        )