File size: 6,648 Bytes
4b0b756
212a439
4b0b756
 
 
212a439
4b0b756
 
 
 
 
 
e86315f
4b0b756
 
 
 
 
 
 
 
 
 
e86315f
4b0b756
 
 
 
 
 
 
 
 
 
 
 
 
 
e86315f
4b0b756
e86315f
 
 
4b0b756
e86315f
4b0b756
 
 
e86315f
4b0b756
 
e86315f
4b0b756
 
 
212a439
e86315f
4b0b756
 
e86315f
4b0b756
e86315f
4b0b756
212a439
e86315f
 
 
4b0b756
 
212a439
4b0b756
 
e86315f
 
4b0b756
e86315f
 
4b0b756
e86315f
 
4b0b756
 
e86315f
4b0b756
 
 
 
 
 
e86315f
 
4b0b756
 
e86315f
4b0b756
 
 
 
e86315f
4b0b756
 
 
e86315f
4b0b756
e86315f
 
4b0b756
 
e86315f
 
4b0b756
 
e86315f
4b0b756
 
e86315f
4b0b756
e86315f
 
a390814
e86315f
 
4b0b756
212a439
e86315f
 
212a439
e86315f
 
212a439
4b0b756
 
 
e86315f
4b0b756
212a439
e86315f
 
 
 
 
 
 
 
 
 
 
 
 
4b0b756
e86315f
212a439
 
e86315f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import io
import numpy as np
import torch
from PIL import Image, ImageFilter
from torchvision import transforms
import gradio as gr
from transformers import AutoModelForImageSegmentation, pipeline

# ----------------------------
# Global Setup and Model Loading
# ----------------------------

# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the foreground segmentation model (BRIA RMBG-2.0). The checkpoint ships
# its own modelling code, which is why trust_remote_code is required.
segmentation_model = AutoModelForImageSegmentation.from_pretrained(
    'briaai/RMBG-2.0',
    trust_remote_code=True
)
segmentation_model.to(device)
segmentation_model.eval()  # inference only: switch layers to evaluation mode

# Preprocessing for the segmentation model: resize to a fixed square, convert to
# a float tensor, then normalize with the standard ImageNet channel mean/std.
# NOTE(review): the RMBG-2.0 model card preprocesses at 1024x1024; 512x512 is
# presumably a speed/quality trade-off here — confirm it is intended.
image_size = (512, 512)
segmentation_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load the depth estimation pipeline (Depth-Anything V2, small variant).
# The device is passed explicitly so the pipeline is placed on the same device
# as the segmentation model instead of relying on the library's default.
depth_pipeline = pipeline(
    "depth-estimation",
    model="depth-anything/Depth-Anything-V2-Small-hf",
    device=device,
)

# ----------------------------
# Processing Functions
# ----------------------------

def segment_and_blur_background(input_image: Image.Image, blur_radius: int = 15, threshold: float = 0.5) -> Image.Image:
    """
    Keep the foreground sharp and Gaussian-blur everything behind it.

    The RMBG-2.0 model scores every pixel; scores above ``threshold`` mark the
    foreground. That hard mask is scaled back to the input's size (bilinear
    resampling softens its edges slightly) and used to blend the original
    pixels over a blurred copy of the whole picture.

    Args:
        input_image: Source image; converted to RGB before processing.
        blur_radius: Radius, in pixels, of the background Gaussian blur.
        threshold: Score cut-off above which a pixel counts as foreground.

    Returns:
        An RGB image with the same dimensions as ``input_image``.
    """
    rgb = input_image.convert("RGB")
    target_size = rgb.size

    batch = segmentation_transform(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        # The model returns a sequence of outputs; the last one is turned into
        # per-pixel foreground scores with a sigmoid.
        scores = segmentation_model(batch)[-1].sigmoid().cpu()[0].squeeze()

    # Hard foreground mask as an 8-bit image: 255 = foreground, 0 = background.
    mask_bytes = (scores > threshold).to(torch.uint8) * 255
    mask = Image.fromarray(mask_bytes.numpy()).resize(target_size, resample=Image.BILINEAR)

    background = rgb.filter(ImageFilter.GaussianBlur(blur_radius))
    return Image.composite(rgb, background, mask)

def depth_based_lens_blur(input_image: Image.Image, max_blur: float = 2, num_bands: int = 40, invert_depth: bool = False) -> Image.Image:
    """
    Simulate a lens blur whose strength varies with estimated scene depth.

    A Depth-Anything depth map is normalized to [0, 1] and split into
    ``num_bands`` equal-width bands. Every pixel in band ``i`` is taken from a
    Gaussian-blurred copy of the image with radius
    ``(1 - (i + 0.5) / num_bands) * max_blur``. Pixels with high normalized
    depth values therefore stay nearly sharp, while low values approach
    ``max_blur``. (Depth-Anything's map is presumably larger-is-nearer, so near
    content stays sharp; use ``invert_depth`` to flip the falloff.)

    Args:
        input_image: Image to process; converted to RGB first.
        max_blur: Largest Gaussian blur radius, in pixels.
        num_bands: Number of depth bands (distinct blur levels). Values below 1
            disable the effect and the RGB image is returned unchanged.
        invert_depth: If True, use ``1 - depth`` so the blur falloff is reversed.

    Returns:
        An RGB image with the same size as ``input_image`` (the effect no longer
        squashes the output to 512x512).
    """
    image = input_image.convert("RGB")
    if num_bands < 1:
        return image

    depth_norm = _normalized_depth_map(image)
    if invert_depth:
        depth_norm = 1.0 - depth_norm

    # Assign every pixel to a band. Clamping puts pixels whose depth is exactly
    # 1.0 into the last band instead of leaving them outside every band.
    band_index = np.minimum((depth_norm * num_bands).astype(np.int64), num_bands - 1)

    result = image.copy()
    # Bands are disjoint, so the order of compositing does not matter. Only
    # bands that actually contain pixels are blurred.
    for band in np.unique(band_index):
        midpoint = (int(band) + 0.5) / num_bands
        radius = (1.0 - midpoint) * max_blur
        blurred = image.filter(ImageFilter.GaussianBlur(radius))
        mask = Image.fromarray(np.where(band_index == band, 255, 0).astype(np.uint8))
        result = Image.composite(blurred, result, mask)
    return result


def _normalized_depth_map(image: Image.Image, working_size: tuple = (512, 512)) -> np.ndarray:
    """
    Estimate depth for ``image`` and return it normalized to [0, 1] at the image's size.

    The image is downscaled to ``working_size`` for inference. The predicted map
    is resized back with bilinear interpolation so it lines up pixel-for-pixel
    with ``image``. The result is a float32 array of shape (height, width).
    """
    prediction = depth_pipeline(image.resize(working_size))
    depth_map = prediction["depth"].convert("F").resize(image.size, resample=Image.BILINEAR)
    depth = np.asarray(depth_map, dtype=np.float32)
    d_min, d_max = depth.min(), depth.max()
    # The epsilon avoids a division by zero when the depth map is perfectly flat.
    return (depth - d_min) / (d_max - d_min + 1e-8)

def process_image(input_image: Image.Image, effect: str, threshold: float, blur_intensity: float) -> Image.Image:
    """
    Dispatch the selected effect for the Gradio interface.

    - "Gaussian Blur Background": segmentation-based background blur; uses
      ``threshold`` and ``blur_intensity`` (truncated to an int radius).
    - "Depth-based Lens Blur": depth-based blur; ``blur_intensity`` is the
      maximum blur radius. ``threshold`` is ignored.

    Args:
        input_image: Uploaded image, or None when nothing was uploaded.
        effect: Name of the effect selected in the UI.
        threshold: Segmentation threshold (segmentation effect only).
        blur_intensity: Blur strength used by both effects.

    Returns:
        The processed image. Returns None when ``input_image`` is None, and the
        input unchanged when ``effect`` is not a known effect name.
    """
    # Gradio's Image input yields None when the user submits without an image;
    # there is nothing to process, so show an empty output instead of crashing.
    if input_image is None:
        return None
    if effect == "Gaussian Blur Background":
        return segment_and_blur_background(input_image, blur_radius=int(blur_intensity), threshold=threshold)
    if effect == "Depth-based Lens Blur":
        return depth_based_lens_blur(input_image, max_blur=blur_intensity)
    # Unknown or missing effect selection: pass the image through unchanged.
    return input_image

# ----------------------------
# Gradio Interface
# ----------------------------

# The four inputs map positionally onto process_image's
# (input_image, effect, threshold, blur_intensity) parameters.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        # Preselect the first effect so submitting without touching the radio
        # still applies an effect instead of echoing the input back unchanged.
        gr.Radio(
            choices=["Gaussian Blur Background", "Depth-based Lens Blur"],
            value="Gaussian Blur Background",
            label="Select Effect",
        ),
        gr.Slider(0.0, 1.0, value=0.5, label="Segmentation Threshold (for Gaussian Blur)"),
        gr.Slider(0, 30, value=15, step=1, label="Blur Intensity (for both effects)")
    ],
    outputs=gr.Image(type="pil", label="Output Image"),
    title="Interactive Blur Effects Demo",
    description=(
        "Upload an image and choose an effect. For 'Gaussian Blur Background', adjust the segmentation threshold and blur intensity. "
        "For 'Depth-based Lens Blur', the blur intensity slider sets the maximum blur based on depth."
    )
)

# Start the web UI only when the file is run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()