import cv2
import numpy as np
from PIL import Image
import torch
import gradio as gr
from torchvision import transforms
from transformers import (
    AutoModelForImageSegmentation,
    DepthProImageProcessorFast,
    DepthProForDepthEstimation,
)

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------------
# Load Segmentation Model (RMBG-2.0 by briaai)
# -----------------------------
seg_model = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0", trust_remote_code=True
)
# Use higher float32 matmul precision for better quality on supported GPUs
torch.set_float32_matmul_precision("high")
seg_model.to(device)
seg_model.eval()

# Define segmentation input size and preprocessing transform
seg_image_size = (1024, 1024)
seg_transform = transforms.Compose([
    transforms.Resize(seg_image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
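
# The normalization constants above are the standard ImageNet mean/std that
# RMBG-2.0 expects. As a quick standalone sanity check (a sketch; "example.jpg"
# is a placeholder for any local photo):
#
#   x = seg_transform(Image.open("example.jpg").convert("RGB")).unsqueeze(0)
#   with torch.no_grad():
#       alpha = seg_model(x.to(device))[-1].sigmoid().cpu()[0].squeeze()
#   # alpha is a (1024, 1024) tensor of foreground probabilities in [0, 1]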

# -----------------------------
# Load Depth Estimation Model (DepthPro by Apple)
# -----------------------------
depth_processor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
depth_model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf")
depth_model.to(device)
depth_model.eval()
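
# Note: DepthPro predicts metric depth (larger values = farther from the
# camera), so after the min-max normalization below, dark regions of the depth
# map are near and bright regions are far. The model also estimates the
# camera's field of view, which this app does not use.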

# -----------------------------
# Define the Segmentation-Based Blur Effect
# -----------------------------
def segmentation_blur_effect(input_image: Image.Image):
    """
    Creates a segmentation mask using RMBG-2.0 and applies a Gaussian blur
    (sigma=15) to the background while keeping the foreground sharp.

    Returns:
        - final image with blurred background (PIL Image)
        - segmentation mask (PIL Image)
        - fully blurred image (PIL Image, for optional display)
    """
    # Ensure a 3-channel RGB input (uploads may arrive as RGBA or grayscale)
    input_image = input_image.convert("RGB")

    # Resize the input for segmentation processing
    imageResized = input_image.resize(seg_image_size)
    input_tensor = seg_transform(imageResized).unsqueeze(0).to(device)

    with torch.no_grad():
        preds = seg_model(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Convert the predicted mask to a PIL image and resize to the original input size
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(input_image.size)

    # Create a binary mask (convert to grayscale, then threshold at 127)
    mask_np = np.array(mask.convert("L"))
    _, maskBinary = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)

    # Convert the original-resolution image to an OpenCV BGR array so the
    # composite matches the mask dimensions
    img = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)

    # Apply Gaussian blur (sigmaX=15, sigmaY=15) to produce the background layer
    blurredBg = cv2.GaussianBlur(img, (0, 0), sigmaX=15, sigmaY=15)

    # Create the inverse mask and expand it to 3 channels
    maskInv = cv2.bitwise_not(maskBinary)
    maskInv3 = cv2.cvtColor(maskInv, cv2.COLOR_GRAY2BGR)

    # Keep the sharp foreground and the blurred background separately
    foreground = cv2.bitwise_and(img, cv2.bitwise_not(maskInv3))
    background = cv2.bitwise_and(blurredBg, maskInv3)

    # Combine the two components and convert back to RGB for display
    finalImg = cv2.cvtColor(cv2.add(foreground, background), cv2.COLOR_BGR2RGB)
    finalImg_pil = Image.fromarray(finalImg)
    blurredBg_pil = Image.fromarray(cv2.cvtColor(blurredBg, cv2.COLOR_BGR2RGB))

    return finalImg_pil, mask, blurredBg_pil
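
# Example usage (a sketch; "example.jpg" is a placeholder for any local photo):
#
#   photo = Image.open("example.jpg").convert("RGB")
#   blurred, mask, background = segmentation_blur_effect(photo)
#   blurred.save("seg_blur.png")
#   mask.save("seg_mask.png")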

# -----------------------------
# Define the Depth-Based Lens Blur Effect
# -----------------------------
def lens_blur_effect(input_image: Image.Image):
    """
    Uses DepthPro to estimate a depth map and applies a dynamic lens blur effect
    by precomputing three versions of the image (foreground, middleground,
    background) with increasing blur, then blending them by estimated depth.

    Returns:
        - depth map (PIL Image)
        - final lens-blurred image (PIL Image)
        - foreground mask (PIL Image)
        - middleground mask (PIL Image)
        - background mask (PIL Image)
    """
    # Ensure a 3-channel RGB input before depth estimation
    input_image = input_image.convert("RGB")

    # Run the depth estimation model
    inputs = depth_processor(images=input_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)
    post_processed_output = depth_processor.post_process_depth_estimation(
        outputs, target_sizes=[(input_image.height, input_image.width)]
    )
    depth = post_processed_output[0]["predicted_depth"]

    # Normalize depth to [0, 255] (0 = nearest, 255 = farthest)
    depth = (depth - depth.min()) / (depth.max() - depth.min())
    depth = depth * 255.0
    depth_map = depth.detach().cpu().numpy().astype(np.uint8)
    depthImg = Image.fromarray(depth_map)
    # Convert the input image to OpenCV BGR format
    img = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)

    # Precompute three versions of the image with increasing blur
    img_foreground = img.copy()  # no blur for the foreground
    img_middleground = cv2.GaussianBlur(img, (0, 0), sigmaX=7, sigmaY=7)
    img_background = cv2.GaussianBlur(img, (0, 0), sigmaX=15, sigmaY=15)

    # Define depth thresholds at 1/3 and 2/3 of the normalized range
    threshold1 = 255 / 3      # ~85
    threshold2 = 2 * 255 / 3  # ~170
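
    # Note: hard thresholds can leave visible seams where regions meet. A
    # smoother alternative (a sketch, not used here) would weight the sharp and
    # blurred layers continuously by normalized depth:
    #
    #   w = (depth_map.astype(np.float32) / 255.0)[..., None]
    #   soft = (img * (1.0 - w) + img_background * w).astype(np.uint8)
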
    # Create masks for the three regions based on depth
    mask_fg = (depth_map < threshold1).astype(np.float32)
    mask_mg = ((depth_map >= threshold1) & (depth_map < threshold2)).astype(np.float32)
    mask_bg = (depth_map >= threshold2).astype(np.float32)

    # Expand masks to 3 channels to match the image dimensions
    mask_fg_3 = np.stack([mask_fg] * 3, axis=-1)
    mask_mg_3 = np.stack([mask_mg] * 3, axis=-1)
    mask_bg_3 = np.stack([mask_bg] * 3, axis=-1)

    # Combine the three versions using the masks (vectorized blending)
    final_img = (img_foreground * mask_fg_3 +
                 img_middleground * mask_mg_3 +
                 img_background * mask_bg_3).astype(np.uint8)
    final_img_rgb = cv2.cvtColor(final_img, cv2.COLOR_BGR2RGB)
    lensBlurImage = Image.fromarray(final_img_rgb)

    # Create displayable mask images (scaled to 0-255)
    mask_fg_img = Image.fromarray((mask_fg * 255).astype(np.uint8))
    mask_mg_img = Image.fromarray((mask_mg * 255).astype(np.uint8))
    mask_bg_img = Image.fromarray((mask_bg * 255).astype(np.uint8))

    return depthImg, lensBlurImage, mask_fg_img, mask_mg_img, mask_bg_img
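
# Example usage (a sketch; "example.jpg" is a placeholder for any local photo):
#
#   photo = Image.open("example.jpg").convert("RGB")
#   depth_img, lens_img, m_fg, m_mg, m_bg = lens_blur_effect(photo)
#   depth_img.save("depth_map.png")
#   lens_img.save("lens_blur.png")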

# -----------------------------
# Gradio App: Process Image and Display Multiple Effects
# -----------------------------
def process_image(input_image: Image.Image):
    """
    Processes the uploaded image to generate:
        1. Segmentation-based Gaussian blur effect.
        2. Segmentation mask.
        3. Depth map.
        4. Depth-based lens blur effect.
        5. Depth-based masks for foreground, middleground, and background.
    """
    seg_blur, seg_mask, _ = segmentation_blur_effect(input_image)
    depth_map_img, lens_blur_img, mask_fg_img, mask_mg_img, mask_bg_img = lens_blur_effect(input_image)
    return (
        seg_blur,
        seg_mask,
        depth_map_img,
        lens_blur_img,
        mask_fg_img,
        mask_mg_img,
        mask_bg_img,
    )

title = "Blur Effects: Gaussian Blur & Depth-Based Lens Blur"
description = (
    "Upload an image to apply two distinct effects:\n\n"
    "1. A segmentation-based Gaussian blur that blurs the background (using RMBG-2.0).\n"
    "2. A depth-based lens blur effect that simulates realistic lens blur based on depth (using DepthPro).\n\n"
    "Outputs include the blurred image, segmentation mask, depth map, lens-blurred image, and depth masks."
)
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Image(type="pil", label="Segmentation-Based Blur"),
        gr.Image(type="pil", label="Segmentation Mask"),
        gr.Image(type="pil", label="Depth Map"),
        gr.Image(type="pil", label="Depth-Based Lens Blur"),
        gr.Image(type="pil", label="Foreground Depth Mask"),
        gr.Image(type="pil", label="Middleground Depth Mask"),
        gr.Image(type="pil", label="Background Depth Mask"),
    ],
    title=title,
    description=description,
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()