Ash2505 committed on
Commit
586457a
·
verified ·
1 Parent(s): e18a03c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -56
app.py CHANGED
@@ -43,11 +43,6 @@ depth_model.eval()
43
  # Define the Segmentation-Based Blur Effect
44
  # -----------------------------
45
  def segmentation_blur_effect(input_image: Image.Image):
46
- """
47
- Creates a segmentation mask using RMBG-2.0 and applies a Gaussian blur (sigma=15)
48
- to the background while keeping the foreground sharp.
49
- """
50
- # Resize input image for segmentation processing
51
  imageResized = input_image.resize(seg_image_size)
52
  input_tensor = seg_transform(imageResized).unsqueeze(0).to(device)
53
 
@@ -55,54 +50,28 @@ def segmentation_blur_effect(input_image: Image.Image):
55
  preds = seg_model(input_tensor)[-1].sigmoid().cpu()
56
  pred = preds[0].squeeze()
57
 
58
- # Convert predicted mask to a PIL image and ensure it matches imageResized's size
59
  pred_pil = transforms.ToPILImage()(pred)
60
  mask = pred_pil.resize(imageResized.size)
61
 
62
- # Convert mask to grayscale and threshold to create a binary mask
63
  mask_np = np.array(mask.convert("L"))
64
  _, maskBinary = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
65
 
66
- # Convert the resized image to an OpenCV BGR array
67
  img = cv2.cvtColor(np.array(imageResized), cv2.COLOR_RGB2BGR)
68
- # Apply Gaussian blur (sigmaX=15, sigmaY=15)
69
  blurredBg = cv2.GaussianBlur(np.array(imageResized), (0, 0), sigmaX=15, sigmaY=15)
70
 
71
- # Create the inverse mask and convert it to 3 channels
72
  maskInv = cv2.bitwise_not(maskBinary)
73
  maskInv3 = cv2.cvtColor(maskInv, cv2.COLOR_GRAY2BGR)
74
 
75
- # Extract the foreground and background using the mask
76
  foreground = cv2.bitwise_and(img, cv2.bitwise_not(maskInv3))
77
  background = cv2.bitwise_and(blurredBg, maskInv3)
78
 
79
- # Combine foreground and background; convert back to RGB for display
80
  finalImg = cv2.add(cv2.cvtColor(foreground, cv2.COLOR_BGR2RGB), background)
81
  finalImg_pil = Image.fromarray(finalImg)
82
 
83
  return finalImg_pil, mask
84
 
85
- # -----------------------------
86
- # Define the Depth-Based Lens Blur Effect with Slider-Controlled Thresholds
87
- # -----------------------------
88
  def lens_blur_effect(input_image: Image.Image, fg_threshold: float = 85, mg_threshold: float = 170):
89
- """
90
- Uses DepthPro to estimate a depth map and applies a dynamic lens blur effect
91
- by blending three versions of the image with increasing blur levels.
92
-
93
- Parameters:
94
- input_image: The original PIL image.
95
- fg_threshold: Foreground threshold (0-255). Pixels with depth below this are considered foreground.
96
- mg_threshold: Middleground threshold (0-255). Pixels with depth between fg_threshold and mg_threshold are middleground.
97
-
98
- Returns:
99
- depthImg: The computed depth map (PIL Image).
100
- lensBlurImage: The final lens-blurred image (PIL Image).
101
- mask_fg_img: Foreground depth mask.
102
- mask_mg_img: Middleground depth mask.
103
- mask_bg_img: Background depth mask.
104
- """
105
- # Process the image with the depth estimation model
106
  inputs = depth_processor(images=input_image, return_tensors="pt").to(device)
107
  with torch.no_grad():
108
  outputs = depth_model(**inputs)
@@ -111,39 +80,32 @@ def lens_blur_effect(input_image: Image.Image, fg_threshold: float = 85, mg_thre
111
  )
112
  depth = post_processed_output[0]["predicted_depth"]
113
 
114
- # Normalize depth to [0, 255]
115
  depth = (depth - depth.min()) / (depth.max() - depth.min())
116
  depth = depth * 255.
117
  depth = depth.detach().cpu().numpy()
118
  depth_map = depth.astype(np.uint8)
119
  depthImg = Image.fromarray(depth_map)
120
 
121
- # Convert input image to OpenCV BGR format
122
  img = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
123
 
124
- # Precompute blurred versions for different depth regions
125
  img_foreground = img.copy() # No blur for foreground
126
  img_middleground = cv2.GaussianBlur(img, (0, 0), sigmaX=7, sigmaY=7)
127
  img_background = cv2.GaussianBlur(img, (0, 0), sigmaX=15, sigmaY=15)
128
 
129
  print(depth_map)
130
- depth_map /= depth_map.max()
131
 
132
- # Use slider values as thresholds
133
- threshold1 = fg_threshold # e.g., default 85
134
- threshold2 = mg_threshold # e.g., default 170
135
 
136
- # Create masks for foreground, middleground, and background based on depth
137
  mask_fg = (depth_map < threshold1).astype(np.float32)
138
  mask_mg = ((depth_map >= threshold1) & (depth_map < threshold2)).astype(np.float32)
139
  mask_bg = (depth_map >= threshold2).astype(np.float32)
140
 
141
- # Expand masks to 3 channels
142
  mask_fg_3 = np.stack([mask_fg]*3, axis=-1)
143
  mask_mg_3 = np.stack([mask_mg]*3, axis=-1)
144
  mask_bg_3 = np.stack([mask_bg]*3, axis=-1)
145
 
146
- # Blend the images using the masks
147
  final_img = (img_foreground * mask_fg_3 +
148
  img_middleground * mask_mg_3 +
149
  img_background * mask_bg_3).astype(np.uint8)
@@ -151,27 +113,14 @@ def lens_blur_effect(input_image: Image.Image, fg_threshold: float = 85, mg_thre
151
  final_img_rgb = cv2.cvtColor(final_img, cv2.COLOR_BGR2RGB)
152
  lensBlurImage = Image.fromarray(final_img_rgb)
153
 
154
- # Create mask images for display (scaled to 0-255)
155
  mask_fg_img = Image.fromarray((mask_fg * 255).astype(np.uint8))
156
  mask_mg_img = Image.fromarray((mask_mg * 255).astype(np.uint8))
157
  mask_bg_img = Image.fromarray((mask_bg * 255).astype(np.uint8))
158
 
159
  return depthImg, lensBlurImage, mask_fg_img, mask_mg_img, mask_bg_img
160
 
161
- # -----------------------------
162
- # Gradio App: Process Image and Display Multiple Effects
163
- # -----------------------------
164
  def process_image(input_image: Image.Image, fg_threshold: float, mg_threshold: float):
165
- """
166
- Processes the uploaded image to generate:
167
- 1. Segmentation-based Gaussian blur effect.
168
- 2. Segmentation mask.
169
- 3. Depth map.
170
- 4. Depth-based lens blur effect.
171
- 5. Depth masks for foreground, middleground, and background.
172
-
173
- The depth thresholds for foreground and middleground regions are adjustable via sliders.
174
- """
175
  seg_blur, seg_mask = segmentation_blur_effect(input_image)
176
  depth_map_img, lens_blur_img, mask_fg_img, mask_mg_img, mask_bg_img = lens_blur_effect(
177
  input_image, fg_threshold, mg_threshold
 
43
  # Define the Segmentation-Based Blur Effect
44
  # -----------------------------
45
  def segmentation_blur_effect(input_image: Image.Image):
 
 
 
 
 
46
  imageResized = input_image.resize(seg_image_size)
47
  input_tensor = seg_transform(imageResized).unsqueeze(0).to(device)
48
 
 
50
  preds = seg_model(input_tensor)[-1].sigmoid().cpu()
51
  pred = preds[0].squeeze()
52
 
 
53
  pred_pil = transforms.ToPILImage()(pred)
54
  mask = pred_pil.resize(imageResized.size)
55
 
 
56
  mask_np = np.array(mask.convert("L"))
57
  _, maskBinary = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
58
 
 
59
  img = cv2.cvtColor(np.array(imageResized), cv2.COLOR_RGB2BGR)
 
60
  blurredBg = cv2.GaussianBlur(np.array(imageResized), (0, 0), sigmaX=15, sigmaY=15)
61
 
 
62
  maskInv = cv2.bitwise_not(maskBinary)
63
  maskInv3 = cv2.cvtColor(maskInv, cv2.COLOR_GRAY2BGR)
64
 
 
65
  foreground = cv2.bitwise_and(img, cv2.bitwise_not(maskInv3))
66
  background = cv2.bitwise_and(blurredBg, maskInv3)
67
 
 
68
  finalImg = cv2.add(cv2.cvtColor(foreground, cv2.COLOR_BGR2RGB), background)
69
  finalImg_pil = Image.fromarray(finalImg)
70
 
71
  return finalImg_pil, mask
72
 
 
 
 
73
  def lens_blur_effect(input_image: Image.Image, fg_threshold: float = 85, mg_threshold: float = 170):
74
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  inputs = depth_processor(images=input_image, return_tensors="pt").to(device)
76
  with torch.no_grad():
77
  outputs = depth_model(**inputs)
 
80
  )
81
  depth = post_processed_output[0]["predicted_depth"]
82
 
 
83
  depth = (depth - depth.min()) / (depth.max() - depth.min())
84
  depth = depth * 255.
85
  depth = depth.detach().cpu().numpy()
86
  depth_map = depth.astype(np.uint8)
87
  depthImg = Image.fromarray(depth_map)
88
 
 
89
  img = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
90
 
 
91
  img_foreground = img.copy() # No blur for foreground
92
  img_middleground = cv2.GaussianBlur(img, (0, 0), sigmaX=7, sigmaY=7)
93
  img_background = cv2.GaussianBlur(img, (0, 0), sigmaX=15, sigmaY=15)
94
 
95
  print(depth_map)
96
+ depth_map = depth_map.astype(np.float32) / depth_map.max()
97
 
98
+ threshold1 = fg_threshold
99
+ threshold2 = mg_threshold
 
100
 
 
101
  mask_fg = (depth_map < threshold1).astype(np.float32)
102
  mask_mg = ((depth_map >= threshold1) & (depth_map < threshold2)).astype(np.float32)
103
  mask_bg = (depth_map >= threshold2).astype(np.float32)
104
 
 
105
  mask_fg_3 = np.stack([mask_fg]*3, axis=-1)
106
  mask_mg_3 = np.stack([mask_mg]*3, axis=-1)
107
  mask_bg_3 = np.stack([mask_bg]*3, axis=-1)
108
 
 
109
  final_img = (img_foreground * mask_fg_3 +
110
  img_middleground * mask_mg_3 +
111
  img_background * mask_bg_3).astype(np.uint8)
 
113
  final_img_rgb = cv2.cvtColor(final_img, cv2.COLOR_BGR2RGB)
114
  lensBlurImage = Image.fromarray(final_img_rgb)
115
 
 
116
  mask_fg_img = Image.fromarray((mask_fg * 255).astype(np.uint8))
117
  mask_mg_img = Image.fromarray((mask_mg * 255).astype(np.uint8))
118
  mask_bg_img = Image.fromarray((mask_bg * 255).astype(np.uint8))
119
 
120
  return depthImg, lensBlurImage, mask_fg_img, mask_mg_img, mask_bg_img
121
 
 
 
 
122
  def process_image(input_image: Image.Image, fg_threshold: float, mg_threshold: float):
123
+
 
 
 
 
 
 
 
 
 
124
  seg_blur, seg_mask = segmentation_blur_effect(input_image)
125
  depth_map_img, lens_blur_img, mask_fg_img, mask_mg_img, mask_bg_img = lens_blur_effect(
126
  input_image, fg_threshold, mg_threshold