amoghrrao commited on
Commit
0bc50d2
·
verified ·
1 Parent(s): 404b99b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -46
app.py CHANGED
@@ -1,22 +1,36 @@
 
1
  import torch
2
  import numpy as np
3
  from PIL import Image, ImageFilter
4
- import gradio as gr
5
  from torchvision import transforms
6
- from transformers import (
7
- AutoModelForImageSegmentation,
8
- AutoProcessor,
9
- AutoModelForDepthEstimation,
10
- )
11
-
12
-
13
 
14
  def load_segmentation_model():
15
  model_name = "ZhengPeng7/BiRefNet"
16
  model = AutoModelForImageSegmentation.from_pretrained(model_name, trust_remote_code=True)
17
  return model
18
 
19
- def segment_image(input_tensor, model):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  with torch.no_grad():
21
  outputs = model(input_tensor)
22
  output_tensor = outputs[0] if isinstance(outputs, list) else outputs.logits
@@ -24,54 +38,50 @@ def segment_image(input_tensor, model):
24
  mask = (mask > 0.5).astype(np.uint8) * 255
25
  return mask
26
 
27
- def load_depth_model():
28
- model_name = "depth-anything/Depth-Anything-V2-Metric-Indoor-Base-hf"
29
- processor = AutoProcessor.from_pretrained(model_name)
30
- model = AutoModelForDepthEstimation.from_pretrained(model_name)
31
- return processor, model
32
-
33
  def estimate_depth(inputs, model):
34
  with torch.no_grad():
35
  outputs = model(**inputs)
36
- return outputs.predicted_depth.squeeze().cpu().numpy()
 
37
 
38
  def normalize_depth_map(depth_map):
39
- return (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
 
 
 
 
 
 
 
 
 
40
 
41
  def apply_depth_based_blur(image, depth_map):
42
  normalized_depth = normalize_depth_map(depth_map)
 
43
  blurred_image = image.copy()
44
-
45
- for y in range(0, image.height, 20):
46
- for x in range(0, image.width, 20):
47
  depth_value = float(normalized_depth[y, x])
48
- blur_radius = depth_value * 20
49
-
50
- box = (max(x-10, 0), max(y-10, 0), min(x+10, image.width), min(y+10, image.height))
51
- cropped = image.crop(box)
52
- blurred_region = cropped.filter(ImageFilter.GaussianBlur(blur_radius))
53
- blurred_image.paste(blurred_region, box)
54
-
55
  return blurred_image
56
 
57
  def process_image_pipeline(image):
58
- original_image = image.convert("RGB").resize((512, 512))
59
-
60
- # Segmentation
61
- seg_transform = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor()])
62
- input_tensor = seg_transform(original_image).unsqueeze(0)
63
- seg_model = load_segmentation_model()
64
- mask = segment_image(input_tensor, seg_model)
65
-
66
- # Depth Estimation
67
  depth_processor, depth_model = load_depth_model()
68
- depth_inputs = depth_processor(images=original_image, return_tensors="pt")
69
- depth_map = estimate_depth(depth_inputs, depth_model)
70
-
71
- # Depth-based Blur
72
- blurred_image = apply_depth_based_blur(original_image, depth_map)
73
-
74
- return original_image, Image.fromarray(mask), Image.fromarray(np.uint8(depth_map / depth_map.max() * 255)), blurred_image
 
 
 
75
 
76
  iface = gr.Interface(
77
  fn=process_image_pipeline,
@@ -80,11 +90,12 @@ iface = gr.Interface(
80
  gr.Image(label="Original Image"),
81
  gr.Image(label="Segmentation Mask"),
82
  gr.Image(label="Depth Map"),
83
- gr.Image(label="Depth-based Blurred Image")
 
84
  ],
85
  title="Segmentation and Depth-Based Image Processing",
86
- description="Upload an image to get segmentation mask, depth map, and depth-based blur effect.",
87
  )
88
 
89
  if __name__ == "__main__":
90
- iface.launch()
 
1
+ import gradio as gr
2
  import torch
3
  import numpy as np
4
  from PIL import Image, ImageFilter
5
+ import matplotlib.pyplot as plt
6
  from torchvision import transforms
7
+ from transformers import AutoProcessor, AutoModelForImageSegmentation, AutoModelForDepthEstimation
 
 
 
 
 
 
8
 
9
  def load_segmentation_model():
10
  model_name = "ZhengPeng7/BiRefNet"
11
  model = AutoModelForImageSegmentation.from_pretrained(model_name, trust_remote_code=True)
12
  return model
13
 
14
+ def load_depth_model():
15
+ model_name = "depth-anything/Depth-Anything-V2-Metric-Indoor-Base-hf"
16
+ processor = AutoProcessor.from_pretrained(model_name)
17
+ model = AutoModelForDepthEstimation.from_pretrained(model_name)
18
+ return processor, model
19
+
20
+ def process_segmentation_image(image):
21
+ transform = transforms.Compose([
22
+ transforms.Resize((512, 512)),
23
+ transforms.ToTensor(),
24
+ ])
25
+ input_tensor = transform(image).unsqueeze(0)
26
+ return image, input_tensor
27
+
28
+ def process_depth_image(image, processor):
29
+ image = image.resize((512, 512))
30
+ inputs = processor(images=image, return_tensors="pt")
31
+ return image, inputs
32
+
33
+ def segment_image(image, input_tensor, model):
34
  with torch.no_grad():
35
  outputs = model(input_tensor)
36
  output_tensor = outputs[0] if isinstance(outputs, list) else outputs.logits
 
38
  mask = (mask > 0.5).astype(np.uint8) * 255
39
  return mask
40
 
 
 
 
 
 
 
41
  def estimate_depth(inputs, model):
42
  with torch.no_grad():
43
  outputs = model(**inputs)
44
+ depth_map = outputs.predicted_depth.squeeze().cpu().numpy()
45
+ return depth_map
46
 
47
  def normalize_depth_map(depth_map):
48
+ min_val = np.min(depth_map)
49
+ max_val = np.max(depth_map)
50
+ normalized_depth = (depth_map - min_val) / (max_val - min_val)
51
+ return normalized_depth
52
+
53
+ def apply_blur(image, mask):
54
+ mask_pil = Image.fromarray(mask).resize(image.size, Image.BILINEAR)
55
+ blurred_background = image.filter(ImageFilter.GaussianBlur(15))
56
+ final_image = Image.composite(image, blurred_background, mask_pil)
57
+ return final_image
58
 
59
  def apply_depth_based_blur(image, depth_map):
60
  normalized_depth = normalize_depth_map(depth_map)
61
+ image = image.resize((512, 512))
62
  blurred_image = image.copy()
63
+ for y in range(image.height):
64
+ for x in range(image.width):
 
65
  depth_value = float(normalized_depth[y, x])
66
+ blur_radius = max(0, depth_value * 20)
67
+ cropped_region = image.crop((max(x-10, 0), max(y-10, 0), min(x+10, image.width), min(y+10, image.height)))
68
+ blurred_region = cropped_region.filter(ImageFilter.GaussianBlur(blur_radius))
69
+ blurred_image.paste(blurred_region, (max(x-10, 0), max(y-10, 0)))
 
 
 
70
  return blurred_image
71
 
72
  def process_image_pipeline(image):
73
+ segmentation_model = load_segmentation_model()
 
 
 
 
 
 
 
 
74
  depth_processor, depth_model = load_depth_model()
75
+
76
+ _, input_tensor = process_segmentation_image(image)
77
+ _, inputs = process_depth_image(image, depth_processor)
78
+
79
+ segmentation_mask = segment_image(image, input_tensor, segmentation_model)
80
+ depth_map = estimate_depth(inputs, depth_model)
81
+ blurred_image = apply_depth_based_blur(image, depth_map)
82
+ gaussian_blur_image = apply_blur(image, segmentation_mask)
83
+
84
+ return image, Image.fromarray(segmentation_mask), Image.fromarray((depth_map / np.max(depth_map) * 255).astype(np.uint8)), blurred_image, gaussian_blur_image
85
 
86
  iface = gr.Interface(
87
  fn=process_image_pipeline,
 
90
  gr.Image(label="Original Image"),
91
  gr.Image(label="Segmentation Mask"),
92
  gr.Image(label="Depth Map"),
93
+ gr.Image(label="Depth-based Blurred Image"),
94
+ gr.Image(label="Gaussian Blur Image")
95
  ],
96
  title="Segmentation and Depth-Based Image Processing",
97
+ description="Upload an image to get segmentation mask, depth map, depth-based blur effect, and Gaussian blur effect."
98
  )
99
 
100
  if __name__ == "__main__":
101
+ iface.launch()