Ash2505 committed
Commit e28b51d · verified · Parent: d3e5ae0

Update app.py

Files changed (1):
  app.py +166 -154
app.py CHANGED
@@ -1,200 +1,212 @@
-import gradio as gr
-from PIL import Image, ImageFilter
-# import matplotlib.pyplot as plt
-import torch
 import cv2
 import numpy as np
 from torchvision import transforms
-from transformers import AutoModelForImageSegmentation, DepthProImageProcessorFast, DepthProForDepthEstimation
-import requests
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-birefnet = AutoModelForImageSegmentation.from_pretrained('ZhengPeng7/BiRefNet', trust_remote_code=True)
-torch.set_float32_matmul_precision(['high', 'highest'][0])
-birefnet.to(device)
-birefnet.eval()
-birefnet.half()
-
-def extract_object(image, t1, t2):
-    # Data settings
-    imageResized = image.resize((512, 512))
-    image_size = (1024, 1024)
-    transform_image = transforms.Compose([
-        transforms.Resize(image_size),
-        transforms.ToTensor(),
-        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-    ])
-
-    # image = Image.open(imagepath)
-    image1 = image.copy()
-    input_images = transform_image(image1).unsqueeze(0).to(device).half()
-
-    # Prediction
     with torch.no_grad():
-        preds = birefnet(input_images)[-1].sigmoid().cpu()
     pred = preds[0].squeeze()
     pred_pil = transforms.ToPILImage()(pred)
-    mask = pred_pil.resize(image1.size)
-    image1.putalpha(mask)
-
-    blurredBg = cv2.GaussianBlur(np.array(imageResized), (0, 0), sigmaX=15, sigmaY=15)
-
-    mask = np.array(result[1].convert("L"))
-    _, maskBinary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
     img = cv2.cvtColor(np.array(imageResized), cv2.COLOR_RGB2BGR)
     maskInv = cv2.bitwise_not(maskBinary)
     maskInv3 = cv2.cvtColor(maskInv, cv2.COLOR_GRAY2BGR)
     foreground = cv2.bitwise_and(img, cv2.bitwise_not(maskInv3))
     background = cv2.bitwise_and(blurredBg, maskInv3)
     finalImg = cv2.add(cv2.cvtColor(foreground, cv2.COLOR_BGR2RGB), background)
-
-    # plt.figure(figsize=(15, 5))
-    # return image1, mask
-
-    # def depth_estimation():
-    imageProcessor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
-    model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf").to(device)
-    inputs = imageProcessor(images=imageResized, return_tensors="pt").to(device)
     with torch.no_grad():
-        outputs = model(**inputs)
-
-    post_processed_output = imageProcessor.post_process_depth_estimation(
-        outputs, target_sizes=[(imageResized.height, imageResized.width)],
     )
-
-    field_of_view = post_processed_output[0]["field_of_view"]
-    focal_length = post_processed_output[0]["focal_length"]
     depth = post_processed_output[0]["predicted_depth"]
     depth = (depth - depth.min()) / (depth.max() - depth.min())
     depth = depth * 255.
     depth = depth.detach().cpu().numpy()
-    # print(depth)
-    depthImg = Image.fromarray(depth.astype("uint8"))
-
-    # threshold1 = 255 / 20 # ~85
-    # threshold2 = 2 * 255 / 3 # ~170
-
-    threshold1 = (t1 / 10) * 255
-    threshold2 = (t2 / 10) * 255
-
-    # Precompute blurred versions for each region
     img_foreground = img.copy()  # No blur for foreground
     img_middleground = cv2.GaussianBlur(img, (0, 0), sigmaX=7, sigmaY=7)
     img_background = cv2.GaussianBlur(img, (0, 0), sigmaX=15, sigmaY=15)
-
-    # Create masks for each region (as float arrays for proper blending)
-    mask_fg = (depth < threshold1).astype(np.float32)
-    mask_mg = ((depth >= threshold1) & (depth < threshold2)).astype(np.float32)
-    mask_bg = (depth >= threshold2).astype(np.float32)
-
-    # Expand masks to 3 channels (H, W, 3)
-    mask_fg = np.stack([mask_fg] * 3, axis=-1)
-    mask_mg = np.stack([mask_mg] * 3, axis=-1)
-    mask_bg = np.stack([mask_bg] * 3, axis=-1)
-
-    # Combine the images using the masks in a vectorized manner.
-    final_img = (img_foreground * mask_fg +
-                 img_middleground * mask_mg +
-                 img_background * mask_bg).astype(np.uint8)
-
-    # Convert the result back to RGB for display with matplotlib.
     final_img_rgb = cv2.cvtColor(final_img, cv2.COLOR_BGR2RGB)
-
-    print("BOTH OUTPUT COMPUTED")
-
-    return image1, final_img
-
-# Visualization
-# plt.axis("off")
-# subplots for 3 images: original, segmented, mask
-
-# plt.figure(figsize=(15, 5))
-
-# image = Image.open('/content/drive/MyDrive/eee515-hw3/hw3-q24.jpg')
-# #resize the image to 512x512
-# imageResized = image.resize((512, 512))
-
-# result = extract_object(birefnet, imageResized)
-# plt.subplot(1, 3, 1)
-# plt.title("Original Resized Image")
-# plt.imshow(imageResized)
-
-# plt.subplot(1, 3, 2)
-# plt.title("Segmented Image")
-# plt.imshow(result[0])
-
-# plt.subplot(1, 3, 3)
-# plt.title("Mask")
-# plt.imshow(result[1], cmap="gray")
-# plt.show()
-
-# Create a Gradio interface
-
-def build_interface(image1, image2):
-    """Build UI for gradio app
     """
-    title = "Bokeh and Lens Blur"
-    with gr.Blocks(theme=gr.themes.Soft(), title=title, fill_width=True) as interface:
-        with gr.Row():
-            # with gr.Column(scale=3):
-            #     with gr.Group():
-            #         input_text_box = gr.Textbox(
-            #             value=None,
-            #             label="Prompt",
-            #             lines=2,
-            #         )
-            #         # gr.Markdown("### Set the values for Middleground and Background")
-            #         # fg = gr.Slider(minimum=0, maximum=99, step=1, value=33, label="Middleground")
-            #         # mg = gr.Slider(minimum=0, maximum=99, step=1, value=66, label="Background")
-            #     with gr.Row():
-            #         submit_button = gr.Button("Submit", variant="primary")
-            with gr.Column(scale=3):
-                model3d = gr.Model3D(
-                    label="Output", height="45em", interactive=False
-                )
-
-            with gr.Column(scale=3):
-                model3d = gr.Model3D(
-                    label="Output", height="45em", interactive=False
-                )
-
-    submit_button.click(
-        handle_text_prompt,
-        inputs=[
-            input_text_box,
-            variance
-        ],
-        outputs=[
-            model3d
-        ]
-    )
-
-    return interface
-
-# demo = gr.Interface(sepia, gr.Image(), "image")
-
-title = "Gaussian Blur Background App"
 description = (
-    "Upload an image to apply a realistic background blur effect. "
-    "The app segments the foreground using RMBG-2.0 and then applies a Gaussian "
-    "blur (σ=15) to the background, simulating a video conferencing blur effect."
 )

 demo = gr.Interface(
-    fn=extract_object,
-    inputs=[gr.Image(type="pil", label="Input Image"), gr.Slider(minimum=0, maximum=40, step=1, value=33, label="Middleground"), gr.Slider(minimum=40, maximum=99, step=1, value=66, label="Background")],
-    outputs=[gr.Image(type="pil", label="Bokeh Image"), gr.Image(type="pil", label="Lens Blur Image")],
     title=title,
     description=description,
     allow_flagging="never"
 )

-# demo = build_interface()
-# demo.queue(default_concurrency_limit=1)
-demo.launch()
 
 import cv2
 import numpy as np
+from PIL import Image, ImageFilter
+import torch
+import gradio as gr
 from torchvision import transforms
+from transformers import (
+    AutoModelForImageSegmentation,
+    DepthProImageProcessorFast,
+    DepthProForDepthEstimation,
+)
+
+# Set device
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# -----------------------------
+# Load Segmentation Model (RMBG-2.0 by briaai)
+# -----------------------------
+seg_model = AutoModelForImageSegmentation.from_pretrained(
+    "briaai/RMBG-2.0", trust_remote_code=True
+)
+# Set higher precision for matmul if desired
+torch.set_float32_matmul_precision(["high", "highest"][0])
+seg_model.to(device)
+seg_model.eval()
+
+# Define segmentation image size and transform
+seg_image_size = (1024, 1024)
+seg_transform = transforms.Compose([
+    transforms.Resize(seg_image_size),
+    transforms.ToTensor(),
+    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+])
+
+# -----------------------------
+# Load Depth Estimation Model (DepthPro by Apple)
+# -----------------------------
+depth_processor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
+depth_model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf")
+depth_model.to(device)
+depth_model.eval()
+
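+# Note: trust_remote_code=True is needed because RMBG-2.0 ships its own
+# modeling code. As an optional, unverified tweak, seg_model.half() (paired
+# with a matching .half() cast on the input tensor) could reduce GPU memory.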
+# -----------------------------
+# Define the Segmentation-Based Blur Effect
+# -----------------------------
+def segmentation_blur_effect(input_image: Image.Image):
+    """
+    Creates a segmentation mask using RMBG-2.0 and applies a Gaussian blur (sigma=15)
+    to the background while keeping the foreground sharp.
+
+    Returns:
+        - final segmented and blurred image (PIL Image)
+        - segmentation mask (PIL Image)
+        - blurred background image (PIL Image) [optional display]
+    """
+    # Resize input for segmentation processing
+    imageResized = input_image.resize(seg_image_size)
+    input_tensor = seg_transform(imageResized).unsqueeze(0).to(device)
+
     with torch.no_grad():
+        preds = seg_model(input_tensor)[-1].sigmoid().cpu()
     pred = preds[0].squeeze()
+
+    # Convert the predicted mask to a PIL image and resize it to match
+    # imageResized, which every OpenCV operation below works on (resizing to
+    # input_image.size would make the mask and image shapes disagree)
     pred_pil = transforms.ToPILImage()(pred)
+    mask = pred_pil.resize(imageResized.size)
+
+    # Create a binary mask (convert to grayscale, then threshold)
+    mask_np = np.array(mask.convert("L"))
+    _, maskBinary = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
+
+    # Convert the resized image to an OpenCV BGR array
     img = cv2.cvtColor(np.array(imageResized), cv2.COLOR_RGB2BGR)
+    # Apply Gaussian blur (sigmaX=15, sigmaY=15); this array stays in RGB
+    # order, matching the RGB foreground it is combined with below
+    blurredBg = cv2.GaussianBlur(np.array(imageResized), (0, 0), sigmaX=15, sigmaY=15)

+    # Create the inverse mask and convert to 3 channels
     maskInv = cv2.bitwise_not(maskBinary)
     maskInv3 = cv2.cvtColor(maskInv, cv2.COLOR_GRAY2BGR)

+    # Extract the foreground and background separately
     foreground = cv2.bitwise_and(img, cv2.bitwise_not(maskInv3))
     background = cv2.bitwise_and(blurredBg, maskInv3)
+
+    # Combine the two components
     finalImg = cv2.add(cv2.cvtColor(foreground, cv2.COLOR_BGR2RGB), background)
+    finalImg_pil = Image.fromarray(finalImg)
+    blurredBg_pil = Image.fromarray(blurredBg)  # already RGB; no BGR2RGB swap needed
+
+    return finalImg_pil, mask, blurredBg_pil
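+
+# Minimal usage sketch ("photo.jpg" is a hypothetical local file):
+#   img = Image.open("photo.jpg").convert("RGB")
+#   seg_blur, seg_mask, blurred_bg = segmentation_blur_effect(img)
+#   seg_blur.save("seg_blur.png")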
+
+# -----------------------------
+# Define the Depth-Based Lens Blur Effect
+# -----------------------------
+def lens_blur_effect(input_image: Image.Image):
+    """
+    Uses DepthPro to estimate a depth map and applies a dynamic lens blur effect
+    by precomputing three versions of the image (foreground, middleground,
+    background) with increasing blur. Regions are blended based on the
+    estimated depth.
+
+    Returns:
+        - Depth map (PIL Image)
+        - Final lens-blurred image (PIL Image)
+        - Foreground mask (PIL Image)
+        - Middleground mask (PIL Image)
+        - Background mask (PIL Image)
+    """
+    # Process the image with the depth estimation model
+    inputs = depth_processor(images=input_image, return_tensors="pt").to(device)
     with torch.no_grad():
+        outputs = depth_model(**inputs)
+    post_processed_output = depth_processor.post_process_depth_estimation(
+        outputs, target_sizes=[(input_image.height, input_image.width)]
     )
     depth = post_processed_output[0]["predicted_depth"]
+
+    # Normalize depth to [0, 255]
     depth = (depth - depth.min()) / (depth.max() - depth.min())
     depth = depth * 255.
     depth = depth.detach().cpu().numpy()
+    depth_map = depth.astype(np.uint8)
+    depthImg = Image.fromarray(depth_map)

+    # Convert input image to OpenCV BGR format
+    img = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)

+    # Precompute three blurred versions of the image
     img_foreground = img.copy()  # No blur for foreground
     img_middleground = cv2.GaussianBlur(img, (0, 0), sigmaX=7, sigmaY=7)
     img_background = cv2.GaussianBlur(img, (0, 0), sigmaX=15, sigmaY=15)

+    # Define depth thresholds (using 1/3 and 2/3 of 255)
+    threshold1 = 255 / 3      # ~85
+    threshold2 = 2 * 255 / 3  # ~170

+    # Create masks for the three regions based on depth
+    mask_fg = (depth_map < threshold1).astype(np.float32)
+    mask_mg = ((depth_map >= threshold1) & (depth_map < threshold2)).astype(np.float32)
+    mask_bg = (depth_map >= threshold2).astype(np.float32)

+    # Expand masks to 3 channels to match image dimensions
+    mask_fg_3 = np.stack([mask_fg] * 3, axis=-1)
+    mask_mg_3 = np.stack([mask_mg] * 3, axis=-1)
+    mask_bg_3 = np.stack([mask_bg] * 3, axis=-1)
+
+    # Combine the images using the masks (vectorized blending)
+    final_img = (img_foreground * mask_fg_3 +
+                 img_middleground * mask_mg_3 +
+                 img_background * mask_bg_3).astype(np.uint8)

     final_img_rgb = cv2.cvtColor(final_img, cv2.COLOR_BGR2RGB)
+    lensBlurImage = Image.fromarray(final_img_rgb)
+
+    # Create mask images (scaled to 0-255)
+    mask_fg_img = Image.fromarray((mask_fg * 255).astype(np.uint8))
+    mask_mg_img = Image.fromarray((mask_mg * 255).astype(np.uint8))
+    mask_bg_img = Image.fromarray((mask_bg * 255).astype(np.uint8))
+
+    return depthImg, lensBlurImage, mask_fg_img, mask_mg_img, mask_bg_img
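+
+# Worked example of the band assignment: with threshold1 = 255/3 ≈ 85.0 and
+# threshold2 = 2*255/3 ≈ 170.0, a pixel with normalized depth 60 stays sharp
+# (foreground), one at 120 gets the sigma=7 blur (middleground), and one at
+# 200 gets the sigma=15 blur (background).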
+
+# -----------------------------
+# Gradio App: Process Image and Display Multiple Effects
+# -----------------------------
+def process_image(input_image: Image.Image):
     """
+    Processes the uploaded image to generate:
+      1. Segmentation-based Gaussian blur effect.
+      2. Segmentation mask.
+      3. Depth map.
+      4. Depth-based lens blur effect.
+      5. Depth-based masks for foreground, middleground, and background.
+    """
+    seg_blur, seg_mask, _ = segmentation_blur_effect(input_image)
+    depth_map_img, lens_blur_img, mask_fg_img, mask_mg_img, mask_bg_img = lens_blur_effect(input_image)
+
+    return (
+        seg_blur,
+        seg_mask,
+        depth_map_img,
+        lens_blur_img,
+        mask_fg_img,
+        mask_mg_img,
+        mask_bg_img
+    )
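+
+# The tuple above is unpacked positionally by Gradio, so its order must stay
+# in sync with the `outputs` list passed to gr.Interface below.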

+title = "Blur Effects: Gaussian Blur & Depth-Based Lens Blur"
 description = (
+    "Upload an image to apply two distinct effects:\n\n"
+    "1. A segmentation-based Gaussian blur that blurs the background (using RMBG-2.0).\n"
+    "2. A depth-based lens blur effect that simulates realistic lens blur based on depth (using DepthPro).\n\n"
+    "Outputs include the blurred image, segmentation mask, depth map, lens-blurred image, and depth masks."
 )

 demo = gr.Interface(
+    fn=process_image,
+    inputs=gr.Image(type="pil", label="Input Image"),
+    outputs=[
+        gr.Image(type="pil", label="Segmentation-Based Blur"),
+        gr.Image(type="pil", label="Segmentation Mask"),
+        gr.Image(type="pil", label="Depth Map"),
+        gr.Image(type="pil", label="Depth-Based Lens Blur"),
+        gr.Image(type="pil", label="Foreground Depth Mask"),
+        gr.Image(type="pil", label="Middleground Depth Mask"),
+        gr.Image(type="pil", label="Background Depth Mask")
+    ],
     title=title,
     description=description,
     allow_flagging="never"
 )

+if __name__ == "__main__":
+    demo.launch()
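+
+# Run locally with `python app.py`; by default launch() serves the UI at
+# http://127.0.0.1:7860.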