File size: 5,851 Bytes
4b0b756
212a439
4b0b756
 
 
212a439
4b0b756
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212a439
4b0b756
 
 
 
 
 
 
212a439
4b0b756
 
 
 
 
212a439
4b0b756
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212a439
4b0b756
212a439
4b0b756
212a439
4b0b756
 
 
 
 
 
212a439
 
 
 
aa66451
 
212a439
aa66451
212a439
4b0b756
 
 
 
212a439
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import io
import numpy as np
import torch
from PIL import Image, ImageFilter
from torchvision import transforms
import gradio as gr
from transformers import AutoModelForImageSegmentation, pipeline

# ----------------------------
# Global Setup and Model Loading
# ----------------------------

# Run inference on the GPU when CUDA is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the segmentation model (RMBG-2.0).
# NOTE: trust_remote_code=True lets the model repository supply its own Python code.
segmentation_model = AutoModelForImageSegmentation.from_pretrained(
    'briaai/RMBG-2.0',
    trust_remote_code=True
)
segmentation_model.to(device)
# Inference only: switch the model to evaluation mode.
segmentation_model.eval()

# Preprocessing for segmentation: resize to 512x512, convert to a tensor and
# normalize each channel with the given mean / std values.
image_size = (512, 512)
segmentation_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load the depth estimation pipeline (Depth-Anything) on the same device as the
# segmentation model, so both models follow the single `device` choice above.
depth_pipeline = pipeline(
    "depth-estimation",
    model="depth-anything/Depth-Anything-V2-Small-hf",
    device=device,
)


# ----------------------------
# Processing Functions
# ----------------------------

def segment_and_blur_background(input_image: Image.Image, blur_radius: int = 15, threshold: float = 0.5) -> Image.Image:
    """
    Keep the segmented subject sharp and replace the rest of the picture with a
    Gaussian-blurred copy of itself.

    Args:
        input_image: Picture to process; it is converted to RGB before use.
        blur_radius: Radius passed to ``ImageFilter.GaussianBlur`` for the background.
        threshold: Cut-off applied to the sigmoid output of the segmentation
            model; pixels scoring above it are kept sharp.

    Returns:
        The composited image, with the same width and height as ``input_image``.
    """
    rgb = input_image.convert("RGB")
    full_size = rgb.size  # (width, height) of the untouched input

    # Model input: transformed image with a leading batch dimension, on the model's device.
    batch = segmentation_transform(rgb).unsqueeze(0).to(device)

    # Gradients are not needed for inference; only the last model output is used.
    with torch.no_grad():
        probabilities = segmentation_model(batch)[-1].sigmoid().cpu()
    subject_prob = probabilities[0].squeeze()

    # Hard 0/1 decision per pixel, turned into an 8-bit greyscale image.
    hard_mask = (subject_prob > threshold).float()
    to_pil = transforms.ToPILImage()
    mask = to_pil(hard_mask).convert("L")
    # Force every pixel to exactly 0 or 255.
    mask = mask.point(lambda value: 255 if value > 128 else 0)
    # Scale the mask up (or down) to the input's size so it lines up pixel for pixel.
    mask = mask.resize(full_size, resample=Image.BILINEAR)

    # Where the mask is white the sharp image shows through; elsewhere the blurred copy is used.
    background = rgb.filter(ImageFilter.GaussianBlur(blur_radius))
    return Image.composite(rgb, background, mask)


def depth_based_lens_blur(input_image: Image.Image, max_blur: float = 2, num_bands: int = 40, invert_depth: bool = False) -> Image.Image:
    """
    Applies a depth-based blur effect using a depth map produced by Depth-Anything.
    The effect simulates a lens blur by applying different blur strengths in depth bands.

    Args:
        input_image: Picture to process. It is resized to 512x512 before anything else.
        max_blur: Upper bound for the Gaussian blur radius. A band's radius is
            ``(1 - band_midpoint) * max_blur``, so bands near normalized depth 0
            get the strongest blur and bands near 1 get the weakest.
        num_bands: Number of equal-width bands the normalized depth range [0, 1]
            is divided into.
        invert_depth: When True, the normalized depth map is flipped (``1 - depth``)
            so the blur ordering is reversed.

    Returns:
        A 512x512 RGB image (the output is not resized back to the input's dimensions).
    """
    # Resize the input image to 512x512 for the depth estimation model
    image_resized = input_image.resize((512, 512))

    # Run depth estimation to obtain the depth map (as a PIL image)
    results = depth_pipeline(image_resized)
    depth_map_image = results['depth']

    # Convert the depth map to a NumPy array and normalize to [0, 1].
    # The small epsilon avoids a division by zero for a completely flat depth map.
    depth_array = np.array(depth_map_image, dtype=np.float32)
    d_min, d_max = depth_array.min(), depth_array.max()
    depth_norm = (depth_array - d_min) / (d_max - d_min + 1e-8)
    if invert_depth:
        depth_norm = 1.0 - depth_norm

    # Convert the resized image to RGBA for compositing
    orig_rgba = image_resized.convert("RGBA")
    final_image = orig_rgba.copy()

    # Divide the normalized depth range into bands
    band_edges = np.linspace(0, 1, num_bands + 1)
    last_band = num_bands - 1
    for i in range(num_bands):
        band_min = band_edges[i]
        band_max = band_edges[i + 1]

        # Bands are half-open [min, max) so that no pixel lands in two bands.
        # The final band is closed on the right; otherwise pixels whose
        # normalized depth is exactly 1.0 would belong to no band at all.
        in_band = depth_norm >= band_min
        if i == last_band:
            in_band &= depth_norm <= band_max
        else:
            in_band &= depth_norm < band_max

        # Nothing to composite for an empty band, so skip the costly blur.
        if not in_band.any():
            continue

        # Use the midpoint of the band to determine the blur strength.
        mid = (band_min + band_max) / 2.0
        blur_radius_band = (1 - mid) * max_blur

        # Create a blurred version of the image for this band.
        blurred_version = orig_rgba.filter(ImageFilter.GaussianBlur(blur_radius_band))

        # 8-bit mask: 255 inside the band, 0 everywhere else.
        band_mask = in_band.astype(np.uint8) * 255
        band_mask_pil = Image.fromarray(band_mask, mode="L")

        # Composite the blurred version with the current final image using the band mask.
        final_image = Image.composite(blurred_version, final_image, band_mask_pil)

    # Return the final composited image as RGB.
    return final_image.convert("RGB")


def process_image(input_image: Image.Image, effect: str) -> Image.Image:
    """
    Dispatch function to apply the selected effect:
    - "Gaussian Blur Background": uses segmentation and Gaussian blur.
    - "Depth-based Lens Blur": applies depth-based blur using the estimated depth map.

    Args:
        input_image: The image to process, or None when no image was supplied.
        effect: Name of the effect to apply. Any other value (including None)
            leaves the image untouched.

    Returns:
        The processed image, the unmodified input for an unrecognized effect,
        or None when ``input_image`` is None.
    """
    # Without an image there is nothing to process. Returning None matches what
    # the fall-through branch below already does for a missing image, instead
    # of failing inside one of the effect functions.
    if input_image is None:
        return None

    if effect == "Gaussian Blur Background":
        return segment_and_blur_background(input_image)
    if effect == "Depth-based Lens Blur":
        return depth_based_lens_blur(input_image)

    # Unknown or unselected effect: hand the input back unchanged.
    return input_image


# ----------------------------
# Gradio Interface
# ----------------------------

# UI definition: an image input and an effect selector are passed to
# process_image, and the returned image is displayed as the output.
iface = gr.Interface(
    title="Blur Effects Demo",
    description="Upload an image and choose an effect: apply segmentation + Gaussian blurred background, or a depth-based lens blur effect.",
    fn=process_image,
    inputs=[
        gr.Image(label="Input Image", type="pil"),
        gr.Radio(
            label="Select Effect",
            choices=["Gaussian Blur Background", "Depth-based Lens Blur"],
        ),
    ],
    outputs=gr.Image(label="Output Image", type="pil"),
)

# Launch the app only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    iface.launch()