eee515 / app.py
kvinod15's picture
Update app.py
4b0b756 verified
raw
history blame
5.85 kB
import io
import numpy as np
import torch
from PIL import Image, ImageFilter
from torchvision import transforms
import gradio as gr
from transformers import AutoModelForImageSegmentation, pipeline
# ----------------------------
# Global Setup and Model Loading
# ----------------------------
# Set device (GPU if available, else CPU); both models below are placed on it.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the segmentation model (RMBG-2.0). trust_remote_code=True lets
# transformers execute the custom model code published in the Hub repository.
segmentation_model = AutoModelForImageSegmentation.from_pretrained(
    'briaai/RMBG-2.0',
    trust_remote_code=True,
)
segmentation_model.to(device)
segmentation_model.eval()  # inference mode: disables training-only behaviour such as dropout
# Define the image transformation for segmentation:
#   1. resize to a fixed 512x512 (aspect ratio is NOT preserved),
#   2. convert to a float tensor in [0, 1],
#   3. normalize with the ImageNet per-channel mean/std.
# NOTE(review): the RMBG-2.0 model card example uses 1024x1024 inputs; 512x512
# is presumably a speed/quality trade-off here — confirm it is intended.
image_size = (512, 512)
segmentation_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# Load the depth estimation pipeline (Depth-Anything V2 Small). The device is
# passed explicitly so the depth model runs on the same device as the
# segmentation model instead of relying on the library default (which, in
# many transformers versions, is the CPU even when a GPU is available).
depth_pipeline = pipeline(
    "depth-estimation",
    model="depth-anything/Depth-Anything-V2-Small-hf",
    device=device,
)
# ----------------------------
# Processing Functions
# ----------------------------
def segment_and_blur_background(input_image: Image.Image, blur_radius: int = 15, threshold: float = 0.5) -> Image.Image:
    """
    Blur the background of an image while keeping the RMBG-2.0 foreground sharp.

    The image is segmented with ``segmentation_model``. Pixels whose predicted
    foreground probability exceeds ``threshold`` are taken from the original
    image; all other pixels come from a Gaussian-blurred copy.

    Args:
        input_image: Image to process (any PIL mode; it is converted to RGB).
        blur_radius: Radius passed to ``ImageFilter.GaussianBlur`` for the
            background, in pixels of the original image.
        threshold: Cut-off applied to the sigmoid output of the model.

    Returns:
        An RGB image with the same size as ``input_image``.
    """
    # Work in RGB so the model always receives three channels.
    image = input_image.convert("RGB")

    # Preprocess (512x512 resize + normalization) and add a batch dimension.
    input_tensor = segmentation_transform(image).unsqueeze(0).to(device)

    # Run inference; the last element of the model output is taken as the
    # mask logits and squashed to [0, 1] with a sigmoid.
    with torch.no_grad():
        preds = segmentation_model(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Hard mask at model resolution: 255 = foreground, 0 = background.
    mask_array = (pred > threshold).numpy().astype(np.uint8) * 255

    # Scale the mask back to the original size. Bilinear resampling turns the
    # hard edge into a short 0..255 ramp, which softens the seam between the
    # sharp foreground and the blurred background.
    mask_pil = Image.fromarray(mask_array).resize(image.size, resample=Image.BILINEAR)

    # Blur the whole image to use as the background layer.
    blurred_image = image.filter(ImageFilter.GaussianBlur(blur_radius))

    # Where the mask is 255 take the original pixel, where it is 0 the blurred one.
    return Image.composite(image, blurred_image, mask_pil)
def _normalized_depth(image: Image.Image, invert_depth: bool = False) -> np.ndarray:
    """Return a float32 depth map for ``image``, min-max normalized to [0, 1], sized like ``image``."""
    # The depth model sees a fixed 512x512 copy (aspect ratio not preserved)
    # to bound inference cost; the result is scaled back to full size below.
    results = depth_pipeline(image.resize((512, 512)))
    depth_array = np.asarray(results["depth"], dtype=np.float32)

    d_min, d_max = float(depth_array.min()), float(depth_array.max())
    # The epsilon avoids a division by zero for a constant depth map
    # (every pixel then maps to 0).
    depth_norm = (depth_array - d_min) / (d_max - d_min + 1e-8)
    if invert_depth:
        depth_norm = 1.0 - depth_norm

    # Upsample in float ("F") mode so the normalized values are interpolated
    # without 8-bit re-quantization; clip guards against rounding drift.
    full_size = Image.fromarray(depth_norm.astype(np.float32)).resize(image.size, resample=Image.BILINEAR)
    return np.clip(np.asarray(full_size, dtype=np.float32), 0.0, 1.0)


def depth_based_lens_blur(input_image: Image.Image, max_blur: float = 2, num_bands: int = 40, invert_depth: bool = False) -> Image.Image:
    """
    Apply a depth-dependent blur that imitates a lens blur.

    A depth map from Depth-Anything is normalized to [0, 1] and split into
    ``num_bands`` equal bands. Every pixel in band ``i`` is replaced by a copy
    of the image blurred with radius ``(1 - mid_i) * max_blur``, where
    ``mid_i`` is the band's midpoint. The lowest normalized depth therefore
    gets a radius close to ``max_blur`` and the highest a radius close to 0.
    NOTE(review): Depth-Anything maps are believed to be brighter for nearer
    surfaces, so with ``invert_depth=False`` the nearest pixels stay sharpest —
    confirm against the model card.

    Args:
        input_image: Image to process (any PIL mode; it is converted to RGB).
        max_blur: Largest Gaussian radius, in pixels of the original image
            (same convention as ``segment_and_blur_background``).
        num_bands: Number of depth bands; values below 1 disable the effect.
        invert_depth: Flip the normalized depth so the other end of the depth
            range receives the strongest blur.

    Returns:
        An RGB image with the same size as ``input_image``.
    """
    # RGB avoids blurring an alpha channel and matches the other effect.
    image = input_image.convert("RGB")
    if num_bands < 1:
        return image

    depth_norm = _normalized_depth(image, invert_depth)

    # Band index per pixel in [0, num_bands - 1]. Clipping places
    # depth_norm == 1.0 in the last band; a half-open [min, max) test per band
    # would leave those pixels out of every band.
    band_index = np.minimum((depth_norm * num_bands).astype(np.int64), num_bands - 1)

    final_image = image.copy()
    # Only bands that actually contain pixels need a (full-image) blur pass.
    for band in np.unique(band_index):
        mid = (band + 0.5) / num_bands  # midpoint of [band/n, (band+1)/n)
        blur_radius_band = (1 - mid) * max_blur
        blurred_version = image.filter(ImageFilter.GaussianBlur(blur_radius_band))
        band_mask = Image.fromarray((band_index == band).astype(np.uint8) * 255)
        # Copy the blurred pixels of this band into the result, in place.
        final_image.paste(blurred_version, mask=band_mask)

    return final_image
def process_image(input_image: Image.Image, effect: str) -> Image.Image:
    """
    Dispatch function that applies the selected effect.

    - "Gaussian Blur Background": segmentation + Gaussian-blurred background.
    - "Depth-based Lens Blur": depth-based blur from the estimated depth map.

    Any other ``effect`` value (including None when no option is selected)
    returns ``input_image`` unchanged.

    Raises:
        gr.Error: if no input image was provided. Gradio shows this as a
            readable message instead of an ``AttributeError`` from the
            effect functions.
    """
    # Gradio passes None for an empty Image input.
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    if effect == "Gaussian Blur Background":
        return segment_and_blur_background(input_image)
    if effect == "Depth-based Lens Blur":
        return depth_based_lens_blur(input_image)
    return input_image
# ----------------------------
# Gradio Interface
# ----------------------------
# The radio choices must match the effect names process_image dispatches on.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        # Pre-select an effect: with no selection, process_image returns the
        # input unchanged, which looks like the app did nothing.
        gr.Radio(
            choices=["Gaussian Blur Background", "Depth-based Lens Blur"],
            value="Gaussian Blur Background",
            label="Select Effect",
        ),
    ],
    outputs=gr.Image(type="pil", label="Output Image"),
    title="Blur Effects Demo",
    description=(
        "Upload an image and choose an effect: "
        "apply segmentation + Gaussian blurred background, or a depth-based lens blur effect."
    ),
)

if __name__ == "__main__":
    iface.launch()