Ash2505 commited on
Commit
9b7d147
·
verified ·
1 Parent(s): e28b51d

Resize changes

Browse files
Files changed (1) hide show
  1. app.py +19 -31
app.py CHANGED
@@ -19,7 +19,6 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
19
  seg_model = AutoModelForImageSegmentation.from_pretrained(
20
  "briaai/RMBG-2.0", trust_remote_code=True
21
  )
22
- # Set higher precision for matmul if desired
23
  torch.set_float32_matmul_precision(["high", "highest"][0])
24
  seg_model.to(device)
25
  seg_model.eval()
@@ -47,13 +46,8 @@ def segmentation_blur_effect(input_image: Image.Image):
47
  """
48
  Creates a segmentation mask using RMBG-2.0 and applies a Gaussian blur (sigma=15)
49
  to the background while keeping the foreground sharp.
50
-
51
- Returns:
52
- - final segmented and blurred image (PIL Image)
53
- - segmentation mask (PIL Image)
54
- - blurred background image (PIL Image) [optional display]
55
  """
56
- # Resize input for segmentation processing
57
  imageResized = input_image.resize(seg_image_size)
58
  input_tensor = seg_transform(imageResized).unsqueeze(0).to(device)
59
 
@@ -61,11 +55,12 @@ def segmentation_blur_effect(input_image: Image.Image):
61
  preds = seg_model(input_tensor)[-1].sigmoid().cpu()
62
  pred = preds[0].squeeze()
63
 
64
- # Convert predicted mask to a PIL image and resize to original input size
65
  pred_pil = transforms.ToPILImage()(pred)
66
- mask = pred_pil.resize(input_image.size)
 
67
 
68
- # Create a binary mask (convert to grayscale, then threshold)
69
  mask_np = np.array(mask.convert("L"))
70
  _, maskBinary = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
71
 
@@ -74,20 +69,19 @@ def segmentation_blur_effect(input_image: Image.Image):
74
  # Apply Gaussian blur (sigmaX=15, sigmaY=15)
75
  blurredBg = cv2.GaussianBlur(np.array(imageResized), (0, 0), sigmaX=15, sigmaY=15)
76
 
77
- # Create the inverse mask and convert to 3 channels
78
  maskInv = cv2.bitwise_not(maskBinary)
79
  maskInv3 = cv2.cvtColor(maskInv, cv2.COLOR_GRAY2BGR)
80
 
81
- # Extract the foreground and background separately
82
  foreground = cv2.bitwise_and(img, cv2.bitwise_not(maskInv3))
83
  background = cv2.bitwise_and(blurredBg, maskInv3)
84
 
85
- # Combine the two components
86
  finalImg = cv2.add(cv2.cvtColor(foreground, cv2.COLOR_BGR2RGB), background)
87
  finalImg_pil = Image.fromarray(finalImg)
88
- blurredBg_pil = Image.fromarray(cv2.cvtColor(blurredBg, cv2.COLOR_BGR2RGB))
89
 
90
- return finalImg_pil, mask, blurredBg_pil
91
 
92
  # -----------------------------
93
  # Define the Depth-Based Lens Blur Effect
@@ -95,15 +89,9 @@ def segmentation_blur_effect(input_image: Image.Image):
95
  def lens_blur_effect(input_image: Image.Image):
96
  """
97
  Uses DepthPro to estimate a depth map and applies a dynamic lens blur effect
98
- by precomputing three versions of the image (foreground, middleground, background)
99
- with increasing blur. Regions are blended based on the estimated depth.
100
-
101
- Returns:
102
- - Depth map (PIL Image)
103
- - Final lens-blurred image (PIL Image)
104
- - Foreground mask (PIL Image)
105
- - Middleground mask (PIL Image)
106
- - Background mask (PIL Image)
107
  """
108
  # Process the image with the depth estimation model
109
  inputs = depth_processor(images=input_image, return_tensors="pt").to(device)
@@ -124,7 +112,7 @@ def lens_blur_effect(input_image: Image.Image):
124
  # Convert input image to OpenCV BGR format
125
  img = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
126
 
127
- # Precompute three blurred versions of the image
128
  img_foreground = img.copy() # No blur for foreground
129
  img_middleground = cv2.GaussianBlur(img, (0, 0), sigmaX=7, sigmaY=7)
130
  img_background = cv2.GaussianBlur(img, (0, 0), sigmaX=15, sigmaY=15)
@@ -133,17 +121,17 @@ def lens_blur_effect(input_image: Image.Image):
133
  threshold1 = 255 / 3 # ~85
134
  threshold2 = 2 * 255 / 3 # ~170
135
 
136
- # Create masks for the three regions based on depth
137
  mask_fg = (depth_map < threshold1).astype(np.float32)
138
  mask_mg = ((depth_map >= threshold1) & (depth_map < threshold2)).astype(np.float32)
139
  mask_bg = (depth_map >= threshold2).astype(np.float32)
140
 
141
- # Expand masks to 3 channels to match image dimensions
142
  mask_fg_3 = np.stack([mask_fg]*3, axis=-1)
143
  mask_mg_3 = np.stack([mask_mg]*3, axis=-1)
144
  mask_bg_3 = np.stack([mask_bg]*3, axis=-1)
145
 
146
- # Combine the images using the masks (vectorized blending)
147
  final_img = (img_foreground * mask_fg_3 +
148
  img_middleground * mask_mg_3 +
149
  img_background * mask_bg_3).astype(np.uint8)
@@ -151,7 +139,7 @@ def lens_blur_effect(input_image: Image.Image):
151
  final_img_rgb = cv2.cvtColor(final_img, cv2.COLOR_BGR2RGB)
152
  lensBlurImage = Image.fromarray(final_img_rgb)
153
 
154
- # Create mask images (scaled to 0-255)
155
  mask_fg_img = Image.fromarray((mask_fg * 255).astype(np.uint8))
156
  mask_mg_img = Image.fromarray((mask_mg * 255).astype(np.uint8))
157
  mask_bg_img = Image.fromarray((mask_bg * 255).astype(np.uint8))
@@ -170,7 +158,7 @@ def process_image(input_image: Image.Image):
170
  4. Depth-based lens blur effect.
171
  5. Depth-based masks for foreground, middleground, and background.
172
  """
173
- seg_blur, seg_mask, _ = segmentation_blur_effect(input_image)
174
  depth_map_img, lens_blur_img, mask_fg_img, mask_mg_img, mask_bg_img = lens_blur_effect(input_image)
175
 
176
  return (
@@ -188,7 +176,7 @@ description = (
188
  "Upload an image to apply two distinct effects:\n\n"
189
  "1. A segmentation-based Gaussian blur that blurs the background (using RMBG-2.0).\n"
190
  "2. A depth-based lens blur effect that simulates realistic lens blur based on depth (using DepthPro).\n\n"
191
- "Outputs include the blurred image, segmentation mask, depth map, lens-blurred image, and depth masks."
192
  )
193
 
194
  demo = gr.Interface(
 
19
  seg_model = AutoModelForImageSegmentation.from_pretrained(
20
  "briaai/RMBG-2.0", trust_remote_code=True
21
  )
 
22
  torch.set_float32_matmul_precision(["high", "highest"][0])
23
  seg_model.to(device)
24
  seg_model.eval()
 
46
  """
47
  Creates a segmentation mask using RMBG-2.0 and applies a Gaussian blur (sigma=15)
48
  to the background while keeping the foreground sharp.
 
 
 
 
 
49
  """
50
+ # Resize input image for segmentation processing
51
  imageResized = input_image.resize(seg_image_size)
52
  input_tensor = seg_transform(imageResized).unsqueeze(0).to(device)
53
 
 
55
  preds = seg_model(input_tensor)[-1].sigmoid().cpu()
56
  pred = preds[0].squeeze()
57
 
58
+ # Convert predicted mask to a PIL image and ensure it matches imageResized's size
59
  pred_pil = transforms.ToPILImage()(pred)
60
+ # Resize mask to match imageResized to avoid size mismatch in OpenCV operations
61
+ mask = pred_pil.resize(imageResized.size)
62
 
63
+ # Convert mask to grayscale and threshold to create a binary mask
64
  mask_np = np.array(mask.convert("L"))
65
  _, maskBinary = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
66
 
 
69
  # Apply Gaussian blur (sigmaX=15, sigmaY=15)
70
  blurredBg = cv2.GaussianBlur(np.array(imageResized), (0, 0), sigmaX=15, sigmaY=15)
71
 
72
+ # Create the inverse mask and convert it to 3 channels
73
  maskInv = cv2.bitwise_not(maskBinary)
74
  maskInv3 = cv2.cvtColor(maskInv, cv2.COLOR_GRAY2BGR)
75
 
76
+ # Extract the foreground and background using the mask
77
  foreground = cv2.bitwise_and(img, cv2.bitwise_not(maskInv3))
78
  background = cv2.bitwise_and(blurredBg, maskInv3)
79
 
80
+ # Combine foreground and background; convert back to RGB for display
81
  finalImg = cv2.add(cv2.cvtColor(foreground, cv2.COLOR_BGR2RGB), background)
82
  finalImg_pil = Image.fromarray(finalImg)
 
83
 
84
+ return finalImg_pil, mask
85
 
86
  # -----------------------------
87
  # Define the Depth-Based Lens Blur Effect
 
89
  def lens_blur_effect(input_image: Image.Image):
90
  """
91
  Uses DepthPro to estimate a depth map and applies a dynamic lens blur effect
92
+ by blending three versions of the image (foreground, middleground, background)
93
+ with increasing blur levels. Returns the depth map, the final lens-blurred image,
94
+ and the depth masks.
 
 
 
 
 
 
95
  """
96
  # Process the image with the depth estimation model
97
  inputs = depth_processor(images=input_image, return_tensors="pt").to(device)
 
112
  # Convert input image to OpenCV BGR format
113
  img = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
114
 
115
+ # Precompute blurred versions for different depth regions
116
  img_foreground = img.copy() # No blur for foreground
117
  img_middleground = cv2.GaussianBlur(img, (0, 0), sigmaX=7, sigmaY=7)
118
  img_background = cv2.GaussianBlur(img, (0, 0), sigmaX=15, sigmaY=15)
 
121
  threshold1 = 255 / 3 # ~85
122
  threshold2 = 2 * 255 / 3 # ~170
123
 
124
+ # Create masks for foreground, middleground, and background based on depth
125
  mask_fg = (depth_map < threshold1).astype(np.float32)
126
  mask_mg = ((depth_map >= threshold1) & (depth_map < threshold2)).astype(np.float32)
127
  mask_bg = (depth_map >= threshold2).astype(np.float32)
128
 
129
+ # Expand masks to 3 channels
130
  mask_fg_3 = np.stack([mask_fg]*3, axis=-1)
131
  mask_mg_3 = np.stack([mask_mg]*3, axis=-1)
132
  mask_bg_3 = np.stack([mask_bg]*3, axis=-1)
133
 
134
+ # Blend the images using the masks (vectorized operation)
135
  final_img = (img_foreground * mask_fg_3 +
136
  img_middleground * mask_mg_3 +
137
  img_background * mask_bg_3).astype(np.uint8)
 
139
  final_img_rgb = cv2.cvtColor(final_img, cv2.COLOR_BGR2RGB)
140
  lensBlurImage = Image.fromarray(final_img_rgb)
141
 
142
+ # Create mask images for display (scaled to 0-255)
143
  mask_fg_img = Image.fromarray((mask_fg * 255).astype(np.uint8))
144
  mask_mg_img = Image.fromarray((mask_mg * 255).astype(np.uint8))
145
  mask_bg_img = Image.fromarray((mask_bg * 255).astype(np.uint8))
 
158
  4. Depth-based lens blur effect.
159
  5. Depth-based masks for foreground, middleground, and background.
160
  """
161
+ seg_blur, seg_mask = segmentation_blur_effect(input_image)
162
  depth_map_img, lens_blur_img, mask_fg_img, mask_mg_img, mask_bg_img = lens_blur_effect(input_image)
163
 
164
  return (
 
176
  "Upload an image to apply two distinct effects:\n\n"
177
  "1. A segmentation-based Gaussian blur that blurs the background (using RMBG-2.0).\n"
178
  "2. A depth-based lens blur effect that simulates realistic lens blur based on depth (using DepthPro).\n\n"
179
+ "Outputs include the blurred image, segmentation mask, depth map, lens-blurred image, and individual depth masks."
180
  )
181
 
182
  demo = gr.Interface(