Update app.py
Browse files
app.py
CHANGED
@@ -1,22 +1,36 @@
|
|
|
|
1 |
import torch
|
2 |
import numpy as np
|
3 |
from PIL import Image, ImageFilter
|
4 |
-
import
|
5 |
from torchvision import transforms
|
6 |
-
from transformers import
|
7 |
-
AutoModelForImageSegmentation,
|
8 |
-
AutoProcessor,
|
9 |
-
AutoModelForDepthEstimation,
|
10 |
-
)
|
11 |
-
|
12 |
-
|
13 |
|
14 |
def load_segmentation_model():
    """Load the pretrained BiRefNet segmentation checkpoint.

    ``trust_remote_code=True`` is forwarded to the loader so that it may run
    the modeling code published alongside the checkpoint.

    Returns:
        The object returned by
        ``AutoModelForImageSegmentation.from_pretrained``.
    """
    return AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet",
        trust_remote_code=True,
    )
|
18 |
|
19 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
with torch.no_grad():
|
21 |
outputs = model(input_tensor)
|
22 |
output_tensor = outputs[0] if isinstance(outputs, list) else outputs.logits
|
@@ -24,54 +38,50 @@ def segment_image(input_tensor, model):
|
|
24 |
mask = (mask > 0.5).astype(np.uint8) * 255
|
25 |
return mask
|
26 |
|
27 |
-
def load_depth_model():
    """Load the Depth Anything V2 (metric, indoor, base) processor and model.

    Returns:
        A ``(processor, model)`` tuple; both halves are loaded from the same
        checkpoint name.
    """
    checkpoint = "depth-anything/Depth-Anything-V2-Metric-Indoor-Base-hf"
    return (
        AutoProcessor.from_pretrained(checkpoint),
        AutoModelForDepthEstimation.from_pretrained(checkpoint),
    )
|
32 |
-
|
33 |
def estimate_depth(inputs, model):
|
34 |
with torch.no_grad():
|
35 |
outputs = model(**inputs)
|
36 |
-
|
|
|
37 |
|
38 |
def normalize_depth_map(depth_map):
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
def apply_depth_based_blur(image, depth_map):
|
42 |
normalized_depth = normalize_depth_map(depth_map)
|
|
|
43 |
blurred_image = image.copy()
|
44 |
-
|
45 |
-
|
46 |
-
for x in range(0, image.width, 20):
|
47 |
depth_value = float(normalized_depth[y, x])
|
48 |
-
blur_radius = depth_value * 20
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
blurred_region = cropped.filter(ImageFilter.GaussianBlur(blur_radius))
|
53 |
-
blurred_image.paste(blurred_region, box)
|
54 |
-
|
55 |
return blurred_image
|
56 |
|
57 |
def process_image_pipeline(image):
|
58 |
-
|
59 |
-
|
60 |
-
# Segmentation
|
61 |
-
seg_transform = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor()])
|
62 |
-
input_tensor = seg_transform(original_image).unsqueeze(0)
|
63 |
-
seg_model = load_segmentation_model()
|
64 |
-
mask = segment_image(input_tensor, seg_model)
|
65 |
-
|
66 |
-
# Depth Estimation
|
67 |
depth_processor, depth_model = load_depth_model()
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
75 |
|
76 |
iface = gr.Interface(
|
77 |
fn=process_image_pipeline,
|
@@ -80,11 +90,12 @@ iface = gr.Interface(
|
|
80 |
gr.Image(label="Original Image"),
|
81 |
gr.Image(label="Segmentation Mask"),
|
82 |
gr.Image(label="Depth Map"),
|
83 |
-
gr.Image(label="Depth-based Blurred Image")
|
|
|
84 |
],
|
85 |
title="Segmentation and Depth-Based Image Processing",
|
86 |
-
description="Upload an image to get segmentation mask, depth map,
|
87 |
)
|
88 |
|
89 |
if __name__ == "__main__":
|
90 |
-
iface.launch()
|
|
|
1 |
+
import gradio as gr
|
2 |
import torch
|
3 |
import numpy as np
|
4 |
from PIL import Image, ImageFilter
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
from torchvision import transforms
|
7 |
+
from transformers import AutoProcessor, AutoModelForImageSegmentation, AutoModelForDepthEstimation
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
def load_segmentation_model():
    """Load the pretrained BiRefNet segmentation checkpoint.

    ``trust_remote_code=True`` is forwarded to the loader so that it may run
    the modeling code published alongside the checkpoint.

    Returns:
        The object returned by
        ``AutoModelForImageSegmentation.from_pretrained``.
    """
    return AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet",
        trust_remote_code=True,
    )
|
13 |
|
14 |
+
def load_depth_model():
    """Load the Depth Anything V2 (metric, indoor, base) processor and model.

    Returns:
        A ``(processor, model)`` tuple; both halves are loaded from the same
        checkpoint name.
    """
    checkpoint = "depth-anything/Depth-Anything-V2-Metric-Indoor-Base-hf"
    return (
        AutoProcessor.from_pretrained(checkpoint),
        AutoModelForDepthEstimation.from_pretrained(checkpoint),
    )
|
19 |
+
|
20 |
+
def process_segmentation_image(image):
    """Convert an image into a batched tensor for the segmentation model.

    Args:
        image: Input accepted by ``transforms.Resize`` / ``transforms.ToTensor``.

    Returns:
        ``(image, input_tensor)``: the input object unchanged, and the image
        resized to 512x512, converted to a tensor, with a leading batch
        dimension added via ``unsqueeze(0)``.
    """
    resize = transforms.Resize((512, 512))
    to_tensor = transforms.ToTensor()
    batch = to_tensor(resize(image)).unsqueeze(0)
    return image, batch
|
27 |
+
|
28 |
+
def process_depth_image(image, processor):
    """Prepare an image for the depth model.

    Args:
        image: Object exposing ``resize((width, height))``.
        processor: Callable invoked as
            ``processor(images=..., return_tensors="pt")``.

    Returns:
        ``(resized_image, inputs)``: the 512x512 resized image and whatever
        the processor returned for it.
    """
    resized = image.resize((512, 512))
    model_inputs = processor(images=resized, return_tensors="pt")
    return resized, model_inputs
|
32 |
+
|
33 |
+
def segment_image(image, input_tensor, model):
|
34 |
with torch.no_grad():
|
35 |
outputs = model(input_tensor)
|
36 |
output_tensor = outputs[0] if isinstance(outputs, list) else outputs.logits
|
|
|
38 |
mask = (mask > 0.5).astype(np.uint8) * 255
|
39 |
return mask
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
def estimate_depth(inputs, model):
    """Run the depth model and return its prediction as a NumPy array.

    Args:
        inputs: Mapping unpacked as keyword arguments into ``model``.
        model: Callable whose result exposes a ``predicted_depth`` tensor.

    Returns:
        ``predicted_depth`` with singleton dimensions squeezed out, moved to
        the CPU and converted with ``.numpy()``.
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        prediction = model(**inputs).predicted_depth
    return prediction.squeeze().cpu().numpy()
|
46 |
|
47 |
def normalize_depth_map(depth_map):
    """Linearly rescale a depth map so its values span ``[0, 1]``.

    Args:
        depth_map: Array-like of depth values.

    Returns:
        A float array of the same shape in which the smallest input value
        maps to 0 and the largest to 1. A constant map (``max == min``)
        yields all zeros, and an empty input yields an empty float array.
    """
    depth = np.asarray(depth_map)
    if depth.size == 0:
        # np.min / np.max raise on empty input; there is nothing to rescale.
        return depth.astype(float)

    min_val = np.min(depth)
    span = np.max(depth) - min_val
    if span == 0:
        # A flat map would otherwise compute 0 / 0 and fill the result with
        # NaN. Dividing by one keeps the usual result dtype and gives zeros.
        span = 1
    return (depth - min_val) / span
|
52 |
+
|
53 |
+
def apply_blur(image, mask):
    """Blur the parts of ``image`` that lie outside the segmentation mask.

    Args:
        image: PIL image to process.
        mask: Array accepted by ``Image.fromarray``; it is resized to the
            image size with bilinear resampling and used as the compositing
            mask.

    Returns:
        The result of ``Image.composite``: the original image where the mask
        is set, and a Gaussian-blurred copy (radius 15) elsewhere.
    """
    matte = Image.fromarray(mask)
    matte = matte.resize(image.size, Image.BILINEAR)
    background = image.filter(ImageFilter.GaussianBlur(15))
    return Image.composite(image, background, matte)
|
58 |
|
59 |
def apply_depth_based_blur(image, depth_map, *, max_radius=20, num_levels=8):
    """Blur an image progressively with depth (larger depth, stronger blur).

    The normalized depth range is split into ``num_levels`` bands. The whole
    image is blurred once per band and the band's pixels are pasted from the
    matching blurred copy. This replaces a per-pixel crop/blur/paste loop that
    ran one Gaussian blur per pixel and whose overlapping pastes left each
    pixel with the blur of a neighbour roughly 10 px away.

    Args:
        image: PIL image to process. It is resized to 512x512, as before.
        depth_map: 2-D array of depth values. It does not need to match the
            image resolution; it is resampled to the image size.
        max_radius: Gaussian radius applied to the deepest band.
        num_levels: Number of depth bands. The shallowest band is left sharp.

    Returns:
        A new 512x512 PIL image with depth-dependent blur applied.
    """
    # nan_to_num guards against a flat depth map normalizing to NaN (0 / 0).
    normalized_depth = np.nan_to_num(
        np.asarray(normalize_depth_map(depth_map), dtype=np.float32)
    )
    image = image.resize((512, 512))

    # The depth array may have a different resolution than the image, so it is
    # resampled (as a float image) rather than indexed with image coordinates.
    depth_image = Image.fromarray(normalized_depth).resize(
        image.size, Image.BILINEAR
    )
    depth = np.clip(np.asarray(depth_image), 0.0, 1.0)

    num_levels = max(int(num_levels), 1)
    # Map depth in [0, 1] to band indices 0 .. num_levels - 1.
    levels = np.minimum((depth * num_levels).astype(np.intp), num_levels - 1)
    deepest = max(num_levels - 1, 1)

    blurred_image = image.copy()
    # Band 0 has radius 0, so the plain copy already holds its pixels.
    for level in range(1, num_levels):
        band = levels == level
        if not band.any():
            continue
        radius = max_radius * level / deepest
        layer = image.filter(ImageFilter.GaussianBlur(radius))
        band_mask = Image.fromarray(band.astype(np.uint8) * 255)
        blurred_image.paste(layer, (0, 0), band_mask)
    return blurred_image
|
71 |
|
72 |
# Models are expensive to load, so they are kept here after the first request.
_MODEL_CACHE = {}


def _get_models():
    """Return ``(segmentation_model, (depth_processor, depth_model))``, loading once."""
    if "models" not in _MODEL_CACHE:
        # Stored as a single entry so a failed load never leaves a partial cache.
        _MODEL_CACHE["models"] = (load_segmentation_model(), load_depth_model())
    return _MODEL_CACHE["models"]


def _depth_to_uint8(depth_map):
    """Scale a depth map by its maximum into a 0-255 ``uint8`` preview array."""
    peak = float(np.max(depth_map))
    if not peak > 0:
        # Avoid dividing by zero (or by a NaN/negative peak); show a black map.
        return np.zeros(np.shape(depth_map), dtype=np.uint8)
    return (depth_map / peak * 255).astype(np.uint8)


def process_image_pipeline(image):
    """Run segmentation, depth estimation and both blur effects on one image.

    Args:
        image: Input image as delivered by the Gradio interface.

    Returns:
        A 5-tuple ``(image, segmentation_mask_image, depth_map_image,
        depth_blurred_image, mask_blurred_image)`` in the order the interface
        outputs expect.
    """
    # Reuse the models across calls instead of reloading them per request.
    segmentation_model, (depth_processor, depth_model) = _get_models()

    _, input_tensor = process_segmentation_image(image)
    _, inputs = process_depth_image(image, depth_processor)

    segmentation_mask = segment_image(image, input_tensor, segmentation_model)
    depth_map = estimate_depth(inputs, depth_model)
    blurred_image = apply_depth_based_blur(image, depth_map)
    gaussian_blur_image = apply_blur(image, segmentation_mask)

    return (
        image,
        Image.fromarray(segmentation_mask),
        Image.fromarray(_depth_to_uint8(depth_map)),
        blurred_image,
        gaussian_blur_image,
    )
|
85 |
|
86 |
iface = gr.Interface(
|
87 |
fn=process_image_pipeline,
|
|
|
90 |
gr.Image(label="Original Image"),
|
91 |
gr.Image(label="Segmentation Mask"),
|
92 |
gr.Image(label="Depth Map"),
|
93 |
+
gr.Image(label="Depth-based Blurred Image"),
|
94 |
+
gr.Image(label="Gaussian Blur Image")
|
95 |
],
|
96 |
title="Segmentation and Depth-Based Image Processing",
|
97 |
+
description="Upload an image to get segmentation mask, depth map, depth-based blur effect, and Gaussian blur effect."
|
98 |
)
|
99 |
|
100 |
if __name__ == "__main__":
|
101 |
+
iface.launch()
|