|
import io |
|
import numpy as np |
|
import torch |
|
from PIL import Image, ImageFilter |
|
from torchvision import transforms |
|
import gradio as gr |
|
from transformers import AutoModelForImageSegmentation, pipeline |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run both models on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Background-removal model (BRIA RMBG-2.0). trust_remote_code lets transformers
# load the custom model code shipped in the model repository.
segmentation_model = AutoModelForImageSegmentation.from_pretrained(
    'briaai/RMBG-2.0',
    trust_remote_code=True
)
segmentation_model.to(device)
segmentation_model.eval()  # inference only: disable training-mode layers

# Preprocessing for the segmentation model: resize to a fixed size, convert to
# a tensor, then normalize with the ImageNet channel mean/std values.
image_size = (512, 512)
segmentation_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Monocular depth estimator used by the lens-blur effect. The device is passed
# explicitly so it runs alongside the segmentation model instead of relying on
# the pipeline's default device selection (historically the CPU).
depth_pipeline = pipeline(
    "depth-estimation",
    model="depth-anything/Depth-Anything-V2-Small-hf",
    device=device,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
def segment_and_blur_background(input_image: Image.Image, blur_radius: int = 15, threshold: float = 0.5) -> Image.Image:
    """
    Keep the segmented subject sharp and blur everything else.

    RMBG-2.0 predicts a per-pixel foreground probability. That map is binarized at
    ``threshold``, resized back to the input resolution, and used as the
    compositing mask between the original pixels (foreground) and a
    Gaussian-blurred copy (background).

    Args:
        input_image: Image to process; converted to RGB first.
        blur_radius: Radius given to ``ImageFilter.GaussianBlur`` for the background.
        threshold: Probability above which a pixel counts as foreground.

    Returns:
        An RGB image with the same size as ``input_image``.
    """
    rgb = input_image.convert("RGB")
    target_size = rgb.size  # (width, height) of the original image

    batch = segmentation_transform(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        # The last element of the model output holds the logits that are turned
        # into probabilities here.
        last_output = segmentation_model(batch)[-1]
        probabilities = last_output.sigmoid().cpu()
    prob_map = probabilities[0].squeeze()

    # 0.0/1.0 float mask; ToPILImage scales float tensors by 255, giving 0/255.
    hard_mask = (prob_map > threshold).float()
    alpha = transforms.ToPILImage()(hard_mask).convert("L")
    # Re-threshold to guarantee a strictly two-level mask before resizing.
    alpha = alpha.point(lambda value: 0 if value <= 128 else 255)
    # Bilinear resampling yields intermediate values along the mask boundary,
    # softening the transition between sharp and blurred regions.
    alpha = alpha.resize(target_size, resample=Image.BILINEAR)

    background = rgb.filter(ImageFilter.GaussianBlur(blur_radius))
    # White mask pixels take the sharp original, black ones the blurred copy.
    return Image.composite(rgb, background, alpha)
|
|
|
|
|
def depth_based_lens_blur(input_image: Image.Image, max_blur: float = 2, num_bands: int = 40, invert_depth: bool = False) -> Image.Image:
    """
    Applies a depth-based blur effect using a depth map produced by Depth-Anything.

    The effect simulates a lens blur by applying different blur strengths in depth
    bands. The depth map is min-max normalized to [0, 1] and split into
    ``num_bands`` equal-width bands. Every pixel is taken from a copy of the image
    blurred with a radius that decreases linearly with its band's normalized
    depth value: the band nearest 1.0 is almost sharp, and the band nearest 0.0
    gets close to the maximum radius.

    Args:
        input_image: Image to process. The output keeps its size and aspect ratio.
        max_blur: Largest Gaussian blur radius, expressed in pixels for an image
            whose long side is 512 px. The radius is multiplied by
            ``max(width, height) / 512`` so the effect has the same relative
            strength at any resolution.
        num_bands: Number of depth bands. One Gaussian blur is computed for each
            band that actually contains pixels.
        invert_depth: If True, use ``1 - normalized depth`` so the opposite end of
            the depth range is kept sharp.

    Returns:
        An RGB image with the same size as ``input_image``.
    """
    # Resolution at which the default max_blur was originally tuned.
    reference_side = 512

    orig_rgba = input_image.convert("RGBA")
    size = orig_rgba.size
    blur_scale = max(size) / reference_side

    # Estimate depth on the undistorted image (the pipeline does its own model
    # resizing), then make sure the map lines up pixel-for-pixel with the image.
    depth_map_image = depth_pipeline(input_image.convert("RGB"))["depth"]
    if depth_map_image.size != size:
        depth_map_image = depth_map_image.resize(size, resample=Image.BILINEAR)

    depth_array = np.asarray(depth_map_image, dtype=np.float32)
    d_min, d_max = depth_array.min(), depth_array.max()
    # The epsilon keeps a constant depth map from dividing by zero.
    depth_norm = (depth_array - d_min) / (d_max - d_min + 1e-8)
    if invert_depth:
        depth_norm = 1.0 - depth_norm

    band_edges = np.linspace(0, 1, num_bands + 1)
    # Band i covers [edges[i], edges[i+1]). Searching the interior edges with
    # side="right" puts a value equal to an interior edge in the upper band, and
    # anything at or above the last interior edge, including exactly 1.0, in the
    # top band. A plain half-open test would leave 1.0 outside every band.
    band_index = np.searchsorted(band_edges[1:-1], depth_norm, side="right")

    final_image = orig_rgba.copy()
    for i in range(num_bands):
        in_band = band_index == i
        if not in_band.any():
            # Compositing through an all-zero mask changes nothing; skip the blur.
            continue

        mid = (band_edges[i] + band_edges[i + 1]) / 2.0
        blur_radius_band = (1 - mid) * max_blur * blur_scale
        blurred_version = orig_rgba.filter(ImageFilter.GaussianBlur(blur_radius_band))

        # A 2-D uint8 array maps to an "L" (8-bit grayscale) mask.
        band_mask_pil = Image.fromarray(in_band.astype(np.uint8) * 255)
        final_image = Image.composite(blurred_version, final_image, band_mask_pil)

    return final_image.convert("RGB")
|
|
|
|
|
def process_image(input_image: Image.Image, effect: str) -> Image.Image:
    """
    Dispatch function to apply the selected effect.

    - "Gaussian Blur Background": uses segmentation and Gaussian blur.
    - "Depth-based Lens Blur": applies depth-based blur using the estimated depth map.

    Any other ``effect`` value, including ``None`` when no option is selected,
    returns ``input_image`` unchanged. A missing image (``None``, e.g. when the
    form is submitted without an upload) is returned as ``None`` rather than
    handed to an effect that would fail on it.
    """
    if input_image is None:
        return None
    if effect == "Gaussian Blur Background":
        return segment_and_blur_background(input_image)
    if effect == "Depth-based Lens Blur":
        return depth_based_lens_blur(input_image)
    return input_image
|
|
|
|
|
|
|
|
|
|
|
|
|
# Web UI: an image plus an effect selector in, the processed image out.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Radio(
            choices=["Gaussian Blur Background", "Depth-based Lens Blur"],
            # Preselect an effect so Submit does not silently return the input
            # unchanged when the user never touches the selector.
            value="Gaussian Blur Background",
            label="Select Effect",
        ),
    ],
    outputs=gr.Image(type="pil", label="Output Image"),
    title="Blur Effects Demo",
    description=(
        "Upload an image and choose an effect: "
        "apply segmentation + Gaussian blurred background, or a depth-based lens blur effect."
    ),
)


if __name__ == "__main__":
    iface.launch()
|
|