import numpy as np
import torch
from PIL import Image, ImageFilter
from torchvision import transforms
import gradio as gr
from transformers import AutoModelForImageSegmentation, pipeline
# ----------------------------
# Global Setup and Model Loading
# ----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the segmentation model (RMBG-2.0)
segmentation_model = AutoModelForImageSegmentation.from_pretrained(
    'briaai/RMBG-2.0',
    trust_remote_code=True
)
segmentation_model.to(device)
segmentation_model.eval()

# Transformation for segmentation (resizes to 512x512 for the model input)
image_size = (512, 512)
segmentation_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load the depth estimation pipeline (Depth-Anything V2)
depth_pipeline = pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
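
# Note: the depth pipeline above runs on the CPU by default. A minimal sketch of
# moving it onto the GPU as well, assuming the standard `device` argument of
# transformers.pipeline (uncomment to use):
#
# depth_pipeline = pipeline(
#     "depth-estimation",
#     model="depth-anything/Depth-Anything-V2-Small-hf",
#     device=0 if torch.cuda.is_available() else -1,
# )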

# ----------------------------
# Processing Functions
# ----------------------------
def segment_and_blur_background(input_image: Image.Image, blur_strength: int = 15, threshold: float = 0.5) -> Image.Image:
    """
    Applies segmentation using the RMBG-2.0 model and composites the sharp original
    image over a Gaussian-blurred copy of itself, based on an adjustable mask
    sensitivity threshold.
    """
    image = input_image.convert("RGB")
    orig_width, orig_height = image.size

    # Preprocess image for segmentation (resizing is only for model inference)
    input_tensor = segmentation_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        preds = segmentation_model(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Create a binary foreground mask with the adjustable threshold (mask sensitivity),
    # then scale it back up to the original image dimensions
    binary_mask = (pred > threshold).float()
    mask_pil = transforms.ToPILImage()(binary_mask).convert("L")
    mask_pil = mask_pil.point(lambda p: 255 if p > 128 else 0)
    mask_pil = mask_pil.resize((orig_width, orig_height), resample=Image.BILINEAR)

    # Composite: keep the original where the mask is white, the blurred copy elsewhere
    blurred_image = image.filter(ImageFilter.GaussianBlur(blur_strength))
    final_image = Image.composite(image, blurred_image, mask_pil)
    return final_image
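
# Standalone usage sketch (hypothetical file names; independent of the Gradio UI below):
#
# img = Image.open("portrait.jpg")
# out = segment_and_blur_background(img, blur_strength=20, threshold=0.6)
# out.save("portrait_blurred_bg.jpg")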

def depth_based_lens_blur(input_image: Image.Image, max_blur: float = 2, num_bands: int = 40, invert_depth: bool = False) -> Image.Image:
    """
    Applies a depth-based lens-blur effect using a depth map produced by Depth-Anything.
    The max_blur parameter controls the maximum blur radius. This function works at the
    original input image size.
    """
    # Use the original image for depth estimation (no resizing)
    image_original = input_image.convert("RGB")

    # Obtain the depth map via the pipeline (the model accepts variable input sizes)
    results = depth_pipeline(image_original)
    depth_map_image = results['depth']
    depth_array = np.array(depth_map_image, dtype=np.float32)

    # Normalize depth to [0, 1]; the epsilon guards against division by zero on flat maps
    d_min, d_max = depth_array.min(), depth_array.max()
    depth_norm = (depth_array - d_min) / (d_max - d_min + 1e-8)
    if invert_depth:
        depth_norm = 1.0 - depth_norm

    orig_rgba = image_original.convert("RGBA")
    final_image = orig_rgba.copy()

    # Split the normalized depth range into bands and composite a progressively
    # blurred copy of the image per band. Depth-Anything typically assigns higher
    # values to closer regions, so (1 - mid) * max_blur keeps the near field sharp
    # and blurs the far field; invert_depth flips this behavior.
    band_edges = np.linspace(0, 1, num_bands + 1)
    for i in range(num_bands):
        band_min = band_edges[i]
        band_max = band_edges[i + 1]
        mid = (band_min + band_max) / 2.0
        blur_radius_band = (1 - mid) * max_blur
        blurred_version = orig_rgba.filter(ImageFilter.GaussianBlur(blur_radius_band))
        if i == num_bands - 1:
            # Make the last band inclusive so pixels with depth_norm == 1.0 are not dropped
            band_mask = ((depth_norm >= band_min) & (depth_norm <= band_max)).astype(np.uint8) * 255
        else:
            band_mask = ((depth_norm >= band_min) & (depth_norm < band_max)).astype(np.uint8) * 255
        band_mask_pil = Image.fromarray(band_mask, mode="L")
        final_image = Image.composite(blurred_version, final_image, band_mask_pil)

    return final_image.convert("RGB")
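
# Standalone usage sketch (hypothetical file names). Setting invert_depth=True flips
# the focal plane: the far field stays sharp and the foreground is blurred instead.
#
# img = Image.open("street.jpg")
# out = depth_based_lens_blur(img, max_blur=5.0, num_bands=40, invert_depth=True)
# out.save("street_lens_blur.jpg")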

def process_image(input_image: Image.Image, effect: str, mask_sensitivity: float, blur_strength: float) -> Image.Image:
    """
    Applies the selected effect:
      - "Gaussian Blur Background": uses segmentation with adjustable mask sensitivity and blur strength.
      - "Depth-based Lens Blur": applies depth-based blur with an adjustable maximum blur.
    """
    if effect == "Gaussian Blur Background":
        return segment_and_blur_background(input_image, blur_strength=int(blur_strength), threshold=mask_sensitivity)
    elif effect == "Depth-based Lens Blur":
        return depth_based_lens_blur(input_image, max_blur=blur_strength)
    else:
        return input_image

# ----------------------------
# Gradio Blocks Layout
# ----------------------------
with gr.Blocks(title="Interactive Blur Effects Demo") as demo:
    gr.Markdown(
        """
        # Interactive Blur Effects Demo
        Upload an image and choose an effect below.
        For **Gaussian Blur Background**, adjust the mask sensitivity (controls segmentation threshold)
        and blur strength (controls Gaussian blur radius).
        For **Depth-based Lens Blur**, the blur strength slider sets the maximum blur intensity.
        """
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            effect_choice = gr.Radio(
                choices=["Gaussian Blur Background", "Depth-based Lens Blur"],
                label="Select Effect",
                value="Gaussian Blur Background"
            )
            mask_sensitivity_slider = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.5, step=0.01,
                label="Mask Sensitivity (for segmentation)"
            )
            blur_strength_slider = gr.Slider(
                minimum=0, maximum=30, value=15, step=1,
                label="Blur Strength"
            )
            run_button = gr.Button("Apply Effect")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Output Image")

    run_button.click(
        fn=process_image,
        inputs=[input_image, effect_choice, mask_sensitivity_slider, blur_strength_slider],
        outputs=output_image
    )

if __name__ == "__main__":
    demo.launch()