"""Gradio demo offering two blur effects for an uploaded image.

"Gaussian Blur Background" segments the subject with RMBG-2.0 and blurs everything
else; "Depth-based Lens Blur" varies the blur radius across the image according to
a depth map estimated by Depth-Anything.
"""

import io
import numpy as np
import torch
from PIL import Image, ImageFilter
from torchvision import transforms
import gradio as gr
from transformers import AutoModelForImageSegmentation, pipeline

# ----------------------------
# Global Setup and Model Loading
# ----------------------------

# Run on the GPU when one is available; both models below are placed on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the segmentation model (RMBG-2.0).
# NOTE: trust_remote_code=True allows model code from the Hub repository to be executed.
segmentation_model = AutoModelForImageSegmentation.from_pretrained(
    'briaai/RMBG-2.0',
    trust_remote_code=True
)
segmentation_model.to(device)
segmentation_model.eval()  # inference only

# Transformation for segmentation (resizes to 512 for the model input, converts to a
# tensor and applies per-channel mean/std normalisation).
image_size = (512, 512)
segmentation_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load the depth estimation pipeline (Depth-Anything) on the same device as the
# segmentation model, so it also benefits from the GPU when one is present.
depth_pipeline = pipeline(
    "depth-estimation",
    model="depth-anything/Depth-Anything-V2-Small-hf",
    device=device,
)

# ----------------------------
# Processing Functions
# ----------------------------

def segment_and_blur_background(input_image: Image.Image, blur_strength: int = 15, threshold: float = 0.5) -> Image.Image:
    """
    Keep the segmented foreground sharp and replace the rest with a blurred copy.

    The image is run through the RMBG-2.0 segmentation model, the resulting
    probabilities are thresholded into a foreground mask, and that mask selects
    between the original pixels and a Gaussian-blurred version of the image.

    Args:
        input_image: Image to process; it is converted to RGB first.
        blur_strength: Radius passed to the Gaussian blur used for the background.
        threshold: Probability above which a pixel counts as foreground.

    Returns:
        The composited image at the original resolution.
    """
    rgb = input_image.convert("RGB")
    full_size = rgb.size  # (width, height)

    # The model only ever sees the resized/normalised tensor; the mask is scaled
    # back up to the original resolution afterwards.
    batch = segmentation_transform(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        probabilities = segmentation_model(batch)[-1].sigmoid().cpu()
    foreground = (probabilities[0].squeeze() > threshold).float()

    # Turn the 0/1 tensor into an 8-bit mask, force it to pure black/white, then
    # resize it to match the source image.
    to_pil = transforms.ToPILImage()
    mask = to_pil(foreground).convert("L")
    mask = mask.point(lambda value: 255 if value > 128 else 0)
    mask = mask.resize(full_size, resample=Image.BILINEAR)

    # White mask pixels take the sharp image, black ones the blurred background.
    background = rgb.filter(ImageFilter.GaussianBlur(blur_strength))
    return Image.composite(rgb, background, mask)

def depth_based_lens_blur(input_image: Image.Image, max_blur: float = 2, num_bands: int = 40, invert_depth: bool = False) -> Image.Image:
    """
    Applies a depth-based blur effect using a depth map produced by Depth-Anything.

    The normalised depth range [0, 1] is split into ``num_bands`` equal bands. Each
    band is blurred with a radius of ``(1 - band_midpoint) * max_blur``, so pixels
    with low normalised depth values receive the strongest blur and pixels near 1.0
    stay almost sharp. The output has the same size as the input image.

    Args:
        input_image: Image to process; it is converted to RGB first.
        max_blur: Upper bound for the Gaussian blur radius.
        num_bands: Number of depth bands (more bands give smoother transitions
            at the cost of more blur passes).
        invert_depth: If True, flip the normalised depth so the blur gradient
            runs in the opposite direction.

    Returns:
        The blurred image in RGB mode.
    """
    # Use the original image for depth estimation (no resizing)
    image_original = input_image.convert("RGB")

    # Obtain depth map using the pipeline (assumes model accepts variable sizes)
    results = depth_pipeline(image_original)
    depth_map_image = results['depth']
    # The band masks must line up pixel-for-pixel with the image; resize defensively
    # in case the pipeline returns a depth map at a different resolution.
    if depth_map_image.size != image_original.size:
        depth_map_image = depth_map_image.resize(image_original.size, resample=Image.BILINEAR)

    # Normalise depth to [0, 1]; the epsilon avoids division by zero on a flat map.
    depth_array = np.array(depth_map_image, dtype=np.float32)
    d_min, d_max = depth_array.min(), depth_array.max()
    depth_norm = (depth_array - d_min) / (d_max - d_min + 1e-8)
    if invert_depth:
        depth_norm = 1.0 - depth_norm

    final_image = image_original.copy()

    band_edges = np.linspace(0, 1, num_bands + 1)
    last_band = num_bands - 1
    for i, (band_min, band_max) in enumerate(zip(band_edges[:-1], band_edges[1:])):
        # Bands are half-open [min, max) except the last one, which is closed so
        # that pixels with a normalised depth of exactly 1.0 are not left out.
        if i == last_band:
            in_band = (depth_norm >= band_min) & (depth_norm <= band_max)
        else:
            in_band = (depth_norm >= band_min) & (depth_norm < band_max)

        # Blurring the full image is the expensive step; skip bands with no pixels.
        if not in_band.any():
            continue

        mid = (band_min + band_max) / 2.0
        blur_radius_band = float((1 - mid) * max_blur)

        blurred_version = image_original.filter(ImageFilter.GaussianBlur(blur_radius_band))
        # A 2-D uint8 array is interpreted as an 8-bit greyscale ("L") mask.
        band_mask_pil = Image.fromarray(in_band.astype(np.uint8) * 255)
        final_image = Image.composite(blurred_version, final_image, band_mask_pil)

    return final_image

def process_image(input_image: Image.Image, effect: str, mask_sensitivity: float, blur_strength: float) -> Image.Image:
    """
    Applies the selected effect:
      - "Gaussian Blur Background": uses segmentation with adjustable mask sensitivity and blur strength.
      - "Depth-based Lens Blur": applies depth-based blur with an adjustable maximum blur.

    Args:
        input_image: Image supplied by the UI; None when nothing has been uploaded.
        effect: Name of the effect to apply (one of the two labels above).
        mask_sensitivity: Segmentation threshold; only used by the Gaussian effect.
        blur_strength: Blur radius for the Gaussian effect (truncated to an int),
            or the maximum blur for the depth-based effect.

    Returns:
        The processed image, or the input image unchanged for an unknown effect.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of failing with an AttributeError.
    """
    if input_image is None:
        raise gr.Error("Please upload an image before applying an effect.")

    if effect == "Gaussian Blur Background":
        return segment_and_blur_background(input_image, blur_strength=int(blur_strength), threshold=mask_sensitivity)
    if effect == "Depth-based Lens Blur":
        return depth_based_lens_blur(input_image, max_blur=blur_strength)
    # Unknown effect name: hand the image back untouched.
    return input_image

# ----------------------------
# Gradio Blocks Layout
# ----------------------------

# Build the UI: controls in the left column, result in the right column.
with gr.Blocks(title="Interactive Blur Effects Demo") as demo:
    # Introductory help text shown at the top of the page.
    gr.Markdown(
        """
        # Interactive Blur Effects Demo
        Upload an image and choose an effect below.  
        For **Gaussian Blur Background**, adjust the mask sensitivity (controls segmentation threshold) 
        and blur strength (controls Gaussian blur radius).  
        For **Depth-based Lens Blur**, the blur strength slider sets the maximum blur intensity.
        """
    )
    
    with gr.Row():
        with gr.Column():
            # type="pil" makes Gradio hand the upload to the callback as a PIL image.
            input_image = gr.Image(type="pil", label="Input Image")
            # The choice strings must match the effect names checked in process_image.
            effect_choice = gr.Radio(
                choices=["Gaussian Blur Background", "Depth-based Lens Blur"],
                label="Select Effect",
                value="Gaussian Blur Background"
            )
            # Passed to process_image as mask_sensitivity (segmentation threshold).
            mask_sensitivity_slider = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.5, step=0.01,
                label="Mask Sensitivity (for segmentation)"
            )
            # Passed to process_image as blur_strength for both effects.
            blur_strength_slider = gr.Slider(
                minimum=0, maximum=30, value=15, step=1,
                label="Blur Strength"
            )
            run_button = gr.Button("Apply Effect")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Output Image")
    
    # Input order here must match the parameter order of process_image.
    run_button.click(
        fn=process_image,
        inputs=[input_image, effect_choice, mask_sensitivity_slider, blur_strength_slider],
        outputs=output_image
    )

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()