import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageFilter
from torchvision import transforms
from transformers import AutoProcessor, AutoModelForImageSegmentation, AutoModelForDepthEstimation

def load_segmentation_model():
    # BiRefNet is a salient-object segmentation model; trust_remote_code is
    # required because it ships custom modeling code.
    model_name = "ZhengPeng7/BiRefNet"
    model = AutoModelForImageSegmentation.from_pretrained(model_name, trust_remote_code=True)
    model.eval()  # inference only
    return model

def load_depth_model():
    # Depth-Anything V2 (metric, indoor variant) predicts per-pixel depth in metres.
    model_name = "depth-anything/Depth-Anything-V2-Metric-Indoor-Base-hf"
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModelForDepthEstimation.from_pretrained(model_name)
    model.eval()  # inference only
    return processor, model

def process_segmentation_image(image):
    # BiRefNet expects ImageNet-normalized RGB input; without Normalize the
    # predicted mask degrades noticeably.
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_tensor = transform(image.convert("RGB")).unsqueeze(0)
    return image, input_tensor

def process_depth_image(image, processor):
    # The processor handles resizing and normalization for Depth-Anything.
    image = image.resize((512, 512))
    inputs = processor(images=image, return_tensors="pt")
    return image, inputs

def segment_image(input_tensor, model):
    with torch.no_grad():
        outputs = model(input_tensor)
        # BiRefNet returns a list of side outputs with the final prediction
        # last; fall back to .logits for models with standard HF outputs.
        output_tensor = outputs[-1] if isinstance(outputs, list) else outputs.logits
        mask = torch.sigmoid(output_tensor).squeeze().cpu().numpy()
        mask = (mask > 0.5).astype(np.uint8) * 255  # binary 0/255 mask
    return mask

def estimate_depth(inputs, model):
    with torch.no_grad():
        outputs = model(**inputs)
    # predicted_depth comes back at the processor's working resolution,
    # not at the original image size.
    depth_map = outputs.predicted_depth.squeeze().cpu().numpy()
    return depth_map
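
# Optional helper (a sketch following the Depth-Anything model card; the
# function name and the bicubic mode are assumptions, not part of the original
# app): upsample the raw prediction to a target (height, width) so the depth
# map is pixel-aligned with the image it came from.
def estimate_depth_at_size(inputs, model, size):
    with torch.no_grad():
        outputs = model(**inputs)
    depth = torch.nn.functional.interpolate(
        outputs.predicted_depth.unsqueeze(1),  # (B, H, W) -> (B, 1, H, W)
        size=size,                             # target (height, width)
        mode="bicubic",
        align_corners=False,
    )
    return depth.squeeze().cpu().numpy()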

def normalize_depth_map(depth_map):
    min_val = np.min(depth_map)
    max_val = np.max(depth_map)
    # Guard against a constant depth map, which would divide by zero.
    normalized_depth = (depth_map - min_val) / max(max_val - min_val, 1e-8)
    return normalized_depth

def apply_blur(image, mask):
    # Composite keeps the subject (mask=255) sharp over a uniformly blurred background.
    mask_pil = Image.fromarray(mask).resize(image.size, Image.BILINEAR)
    blurred_background = image.filter(ImageFilter.GaussianBlur(15))
    final_image = Image.composite(image, blurred_background, mask_pil)
    return final_image
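
# Optional variant (a hedged sketch, not part of the original pipeline):
# feathering the binary mask before compositing softens the hard cut-out edge
# around the subject. The 5 px radius is an assumption to tune per image size.
def apply_blur_feathered(image, mask, feather_radius=5):
    mask_pil = Image.fromarray(mask).resize(image.size, Image.BILINEAR)
    soft_mask = mask_pil.filter(ImageFilter.GaussianBlur(feather_radius))
    blurred_background = image.filter(ImageFilter.GaussianBlur(15))
    return Image.composite(image, blurred_background, soft_mask)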

def apply_depth_based_blur(image, depth_map):
    # Per-pixel variable blur: farther pixels (larger normalized depth) get a
    # larger Gaussian radius, mimicking a lens blur. Note this runs one filter
    # call per pixel (~260k on a 512x512 image), so it is very slow; the
    # layered variant below is a faster approximation.
    normalized_depth = normalize_depth_map(depth_map)
    image = image.resize((512, 512))
    # Resize the depth map to the image size so the (y, x) indices line up;
    # the raw prediction is at the processor's resolution, not 512x512.
    depth_img = Image.fromarray((normalized_depth * 255).astype(np.uint8))
    normalized_depth = np.array(depth_img.resize(image.size, Image.BILINEAR), dtype=np.float32) / 255.0
    blurred_image = image.copy()
    for y in range(image.height):
        for x in range(image.width):
            depth_value = float(normalized_depth[y, x])
            blur_radius = max(0, depth_value * 20)
            cropped_region = image.crop((max(x-10, 0), max(y-10, 0), min(x+10, image.width), min(y+10, image.height)))
            blurred_region = cropped_region.filter(ImageFilter.GaussianBlur(blur_radius))
            blurred_image.paste(blurred_region, (max(x-10, 0), max(y-10, 0)))
    return blurred_image
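
# Faster approximation (a sketch; the function name and the choice of five
# discrete blur levels are assumptions): pre-blur the whole image at a few
# radii and pick per pixel by depth band, replacing the per-pixel filter calls
# above with num_levels full-image blurs.
def apply_depth_based_blur_layered(image, depth_map, num_levels=5, max_radius=20):
    image = image.resize((512, 512))
    normalized_depth = normalize_depth_map(depth_map)
    # Align the depth map with the image.
    depth_img = Image.fromarray((normalized_depth * 255).astype(np.uint8))
    depth = np.array(depth_img.resize(image.size, Image.BILINEAR), dtype=np.float32) / 255.0
    result = np.array(image)
    for level in range(num_levels):
        radius = max_radius * level / (num_levels - 1)  # 0 .. max_radius
        blurred = np.array(image.filter(ImageFilter.GaussianBlur(radius)))
        lo = level / num_levels
        hi = (level + 1) / num_levels if level < num_levels - 1 else 1.01
        band = (depth >= lo) & (depth < hi)
        result[band] = blurred[band]
    return Image.fromarray(result)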

# Load the models once at startup rather than on every request.
segmentation_model = load_segmentation_model()
depth_processor, depth_model = load_depth_model()

def process_image_pipeline(image):
    _, input_tensor = process_segmentation_image(image)
    _, inputs = process_depth_image(image, depth_processor)

    segmentation_mask = segment_image(input_tensor, segmentation_model)
    depth_map = estimate_depth(inputs, depth_model)
    blurred_image = apply_depth_based_blur(image, depth_map)
    gaussian_blur_image = apply_blur(image, segmentation_mask)

    depth_visual = Image.fromarray((depth_map / np.max(depth_map) * 255).astype(np.uint8))
    return image, Image.fromarray(segmentation_mask), depth_visual, blurred_image, gaussian_blur_image

iface = gr.Interface(
    fn=process_image_pipeline,
    inputs=gr.Image(type="pil"),
    # The function returns five images, so the interface must declare five
    # outputs; leaving any commented out makes Gradio error at runtime.
    outputs=[
        gr.Image(label="Original Image"),
        gr.Image(label="Segmentation Mask"),
        gr.Image(label="Depth Map"),
        gr.Image(label="Lens Blur Effect"),
        gr.Image(label="Gaussian Blur Effect")
    ],
    title="Segmentation and Depth-Based Image Processing",
    description="Upload an image to get a segmentation mask, a depth map, a depth-based lens blur effect, and a Gaussian blur effect."
)

if __name__ == "__main__":
    iface.launch(share=True)