import numpy as np
import torch
import gradio as gr
from PIL import Image, ImageFilter
from torchvision import transforms
from transformers import AutoModelForImageSegmentation, pipeline

# ----------------------------
# Global Setup and Model Loading
# ----------------------------
# Set device (GPU if available, else CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the segmentation model (RMBG-2.0)
segmentation_model = AutoModelForImageSegmentation.from_pretrained(
    'briaai/RMBG-2.0',
    trust_remote_code=True,
)
segmentation_model.to(device)
segmentation_model.eval()

# Image transform for segmentation: resize to 512x512, convert to a tensor,
# and normalize with the standard ImageNet statistics.
image_size = (512, 512)
segmentation_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load the depth estimation pipeline (Depth-Anything V2 Small)
depth_pipeline = pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
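
# Note: the depth-estimation pipeline returns a dict whose "depth" entry is the
# depth map as a PIL image (and "predicted_depth" is the raw tensor).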

# ----------------------------
# Processing Functions
# ----------------------------
def segment_and_blur_background(input_image: Image.Image, blur_radius: int = 15, threshold: float = 0.5) -> Image.Image:
    """
    Use the RMBG-2.0 segmentation model to create a binary foreground mask,
    then composite the sharp foreground over a Gaussian-blurred background.
    The segmentation threshold is adjustable.
    """
    # Ensure the image is in RGB and record its original dimensions.
    image = input_image.convert("RGB")
    orig_width, orig_height = image.size

    # Preprocess the image for segmentation.
    input_tensor = segmentation_transform(image).unsqueeze(0).to(device)

    # Run inference; take the model's final output and map logits to [0, 1].
    with torch.no_grad():
        preds = segmentation_model(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Create a binary mask using the adjustable threshold.
    binary_mask = (pred > threshold).float()
    mask_pil = transforms.ToPILImage()(binary_mask).convert("L")
    # Snap the grayscale mask to pure binary values (0 or 255).
    mask_pil = mask_pil.point(lambda p: 255 if p > 128 else 0)
    # Resize the mask back to the original image dimensions.
    mask_pil = mask_pil.resize((orig_width, orig_height), resample=Image.BILINEAR)

    # Blur the entire image to produce the background layer.
    blurred_image = image.filter(ImageFilter.GaussianBlur(blur_radius))
    # Composite: where the mask is white the original (foreground) shows;
    # elsewhere the blurred background shows.
    final_image = Image.composite(image, blurred_image, mask_pil)
    return final_image
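
# A minimal usage sketch for local testing (the filenames are illustrative,
# not part of this app):
#
#   img = Image.open("portrait.jpg")
#   out = segment_and_blur_background(img, blur_radius=20, threshold=0.4)
#   out.save("portrait_blurred_bg.jpg")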


def depth_based_lens_blur(input_image: Image.Image, max_blur: float = 2, num_bands: int = 40, invert_depth: bool = False) -> Image.Image:
    """
    Apply a depth-based lens-blur effect using a depth map from Depth-Anything.
    max_blur (controlled by a slider) sets the strongest blur radius.
    Note: the output is returned at the model's working resolution of 512x512.
    """
    # Resize the input image to 512x512 for the depth estimation model.
    image_resized = input_image.resize((512, 512))

    # Run depth estimation to obtain the depth map (as a PIL image).
    results = depth_pipeline(image_resized)
    depth_map_image = results["depth"]

    # Convert the depth map to a NumPy array and normalize to [0, 1].
    depth_array = np.array(depth_map_image, dtype=np.float32)
    d_min, d_max = depth_array.min(), depth_array.max()
    depth_norm = (depth_array - d_min) / (d_max - d_min + 1e-8)
    if invert_depth:
        depth_norm = 1.0 - depth_norm

    # Convert the resized image to RGBA for compositing.
    orig_rgba = image_resized.convert("RGBA")
    final_image = orig_rgba.copy()

    # Divide the normalized depth range into bands and blur each band with a
    # radius proportional to (1 - depth), so nearer regions stay sharper.
    band_edges = np.linspace(0, 1, num_bands + 1)
    for i in range(num_bands):
        band_min = band_edges[i]
        band_max = band_edges[i + 1]
        # Use the midpoint of the band to determine the blur strength.
        mid = (band_min + band_max) / 2.0
        blur_radius_band = (1 - mid) * max_blur
        # Create a blurred version of the image for this band.
        blurred_version = orig_rgba.filter(ImageFilter.GaussianBlur(blur_radius_band))
        # Mask the pixels whose normalized depth falls within this band; close
        # the last band on the right so pixels at depth_norm == 1.0 are included.
        if i == num_bands - 1:
            in_band = (depth_norm >= band_min) & (depth_norm <= band_max)
        else:
            in_band = (depth_norm >= band_min) & (depth_norm < band_max)
        band_mask = in_band.astype(np.uint8) * 255
        band_mask_pil = Image.fromarray(band_mask, mode="L")
        # Composite this band's blurred pixels into the running result.
        final_image = Image.composite(blurred_version, final_image, band_mask_pil)

    # Return the final composited image as RGB.
    return final_image.convert("RGB")
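
# Worked example of the band math above: with num_bands=40 and max_blur=15, the
# farthest band (midpoint 0.0125) is blurred with radius (1 - 0.0125) * 15 ≈ 14.8,
# while the nearest band (midpoint 0.9875) gets (1 - 0.9875) * 15 ≈ 0.19, so the
# blur falls off smoothly from background to foreground.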


def process_image(input_image: Image.Image, effect: str, threshold: float, blur_intensity: float) -> Image.Image:
    """
    Dispatch to the selected effect:
    - "Gaussian Blur Background": segmentation with an adjustable threshold and blur radius.
    - "Depth-based Lens Blur": depth-based blur with an adjustable maximum blur.
    The threshold slider applies only to the segmentation effect; the
    blur_intensity slider controls the blur strength in both effects.
    """
    if effect == "Gaussian Blur Background":
        # Use the threshold and blur_intensity (cast to int as the blur radius).
        return segment_and_blur_background(input_image, blur_radius=int(blur_intensity), threshold=threshold)
    elif effect == "Depth-based Lens Blur":
        # Use blur_intensity as the maximum blur value.
        return depth_based_lens_blur(input_image, max_blur=blur_intensity)
    else:
        return input_image
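
# Example dispatch call (values mirror the slider defaults below; hypothetical
# usage, not invoked by the app itself):
#
#   out = process_image(img, "Gaussian Blur Background", threshold=0.5, blur_intensity=15)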

# ----------------------------
# Gradio Interface
# ----------------------------
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Radio(choices=["Gaussian Blur Background", "Depth-based Lens Blur"], label="Select Effect"),
        gr.Slider(0.0, 1.0, value=0.5, label="Segmentation Threshold (for Gaussian Blur)"),
        gr.Slider(0, 30, value=15, step=1, label="Blur Intensity (for both effects)"),
    ],
    outputs=gr.Image(type="pil", label="Output Image"),
    title="Interactive Blur Effects Demo",
    description=(
        "Upload an image and choose an effect. For 'Gaussian Blur Background', adjust the "
        "segmentation threshold and blur intensity. For 'Depth-based Lens Blur', the blur "
        "intensity slider sets the maximum blur based on depth."
    ),
)

if __name__ == "__main__":
    iface.launch()
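
# When running locally (outside Hugging Face Spaces), Gradio can expose a
# temporary public URL via iface.launch(share=True).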