# EEE515-HW3 / app.py
import gradio as gr
from PIL import Image
import torch
import cv2
import numpy as np
from torchvision import transforms
from transformers import (
    AutoModelForImageSegmentation,
    DepthProImageProcessorFast,
    DepthProForDepthEstimation,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load DepthPro once at startup; re-loading it on every request would be slow
imageProcessor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
depthModel = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf").to(device)
depthModel.eval()

birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
torch.set_float32_matmul_precision("high")
birefnet.to(device)
birefnet.eval()
birefnet.half()
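# Assumption: the half-precision weights target a CUDA device. A CPU-only
# fallback sketch (not part of the original app):
#   if device.type != "cuda":
#       birefnet.float()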
def extract_object(image, t1, t2):
    """Segment the foreground with BiRefNet, then build a bokeh image and a
    depth-layered lens-blur image from the input."""
    image = image.convert("RGB")  # Gradio may hand over RGBA or palette images

    # Data settings: BiRefNet expects 1024x1024 ImageNet-normalized input
    image_size = (1024, 1024)
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    image1 = image.copy()
    input_images = transform_image(image1).unsqueeze(0).to(device).half()

    # Prediction: BiRefNet returns multi-scale maps; [-1] is the final refinement
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image1.size)
    image1.putalpha(mask)
    # Bokeh: blur the whole frame, then composite the sharp subject back in
    blurredBg = cv2.GaussianBlur(np.array(image), (0, 0), sigmaX=15, sigmaY=15)
    maskArr = np.array(mask.convert("L"))
    _, maskBinary = cv2.threshold(maskArr, 127, 255, cv2.THRESH_BINARY)
    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    maskInv = cv2.bitwise_not(maskBinary)
    maskInv3 = cv2.cvtColor(maskInv, cv2.COLOR_GRAY2BGR)
    foreground = cv2.bitwise_and(img, cv2.bitwise_not(maskInv3))  # sharp subject (BGR)
    background = cv2.bitwise_and(blurredBg, maskInv3)             # blurred surroundings (RGB)
    finalImg = cv2.add(cv2.cvtColor(foreground, cv2.COLOR_BGR2RGB), background)
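    # Softer alternative (sketch, not used by the app): treat the float matte
    # as an alpha channel and feather the seam instead of hard-thresholding:
    #   alpha = np.array(mask, dtype=np.float32) / 255.0
    #   finalImg = (np.array(image) * alpha[..., None]
    #               + blurredBg * (1.0 - alpha[..., None])).astype(np.uint8)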
    # Depth estimation with Apple DepthPro (loaded once at module level above)
    inputs = imageProcessor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depthModel(**inputs)
    post_processed_output = imageProcessor.post_process_depth_estimation(
        outputs, target_sizes=[(image.height, image.width)],
    )
    field_of_view = post_processed_output[0]["field_of_view"]  # unused, kept for reference
    focal_length = post_processed_output[0]["focal_length"]    # unused, kept for reference
    depth = post_processed_output[0]["predicted_depth"]
    # Normalize metric depth to [0, 255]; DepthPro predicts distance, so small
    # values are near (foreground) and large values are far (background)
    depth = (depth - depth.min()) / (depth.max() - depth.min())
    depth = depth * 255.
    depth = depth.detach().cpu().numpy()
    depthImg = Image.fromarray(depth.astype("uint8"))  # kept for debugging
    # The sliders are percentages of the normalized depth range, mapped onto
    # [0, 255] (the earlier hard-coded cutoffs were 255/3 and 2*255/3)
    threshold1 = (t1 / 100) * 255
    threshold2 = (t2 / 100) * 255
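    # Worked example with the slider defaults: t1=33 -> 0.33 * 255 ~ 84 and
    # t2=66 -> 0.66 * 255 ~ 168, i.e. roughly the 255/3 and 2*255/3 splits
    # noted above, carving depth into near / middle / far bands.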
    # Precompute blurred versions for each region
    img_foreground = img.copy()  # no blur for the foreground
    img_middleground = cv2.GaussianBlur(img, (0, 0), sigmaX=7, sigmaY=7)
    img_background = cv2.GaussianBlur(img, (0, 0), sigmaX=15, sigmaY=15)

    # Create masks for each region (as float arrays for proper blending);
    # the three masks partition the image, so they sum to 1 at every pixel
    mask_fg = (depth < threshold1).astype(np.float32)
    mask_mg = ((depth >= threshold1) & (depth < threshold2)).astype(np.float32)
    mask_bg = (depth >= threshold2).astype(np.float32)

    # Expand masks to 3 channels (H, W, 3)
    mask_fg = np.stack([mask_fg] * 3, axis=-1)
    mask_mg = np.stack([mask_mg] * 3, axis=-1)
    mask_bg = np.stack([mask_bg] * 3, axis=-1)

    # Combine the three blur levels using the masks in a vectorized blend
    final_img = (img_foreground * mask_fg +
                 img_middleground * mask_mg +
                 img_background * mask_bg).astype(np.uint8)

    # img is BGR (OpenCV convention), so convert the blend back to RGB
    final_img_rgb = cv2.cvtColor(final_img, cv2.COLOR_BGR2RGB)
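    # Seam-softening sketch (assumption, not in the app): Gaussian-blurring the
    # region masks before blending feathers the band boundaries, e.g.
    #   mask_fg = cv2.GaussianBlur(mask_fg, (0, 0), sigmaX=5)  # likewise mg/bg
    #   total = mask_fg + mask_mg + mask_bg
    #   mask_fg, mask_mg, mask_bg = mask_fg / total, mask_mg / total, mask_bg / total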
    # Return the segmentation-based bokeh and the depth-based lens blur
    return Image.fromarray(finalImg), Image.fromarray(final_img_rgb)
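# Local smoke test (sketch; assumes a sample image "test.jpg" beside app.py):
#   from PIL import Image
#   bokeh, lens = extract_object(Image.open("test.jpg"), 33, 66)
#   bokeh.save("bokeh.jpg"); lens.save("lens.jpg")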
# Create a Gradio interface.
# Alternative Blocks-based layout, kept for reference; the gr.Interface
# defined below is the one the app launches.
def build_interface():
    """Build a Blocks UI for the Gradio app."""
    title = "Bokeh and Lens Blur"
    with gr.Blocks(theme=gr.themes.Soft(), title=title, fill_width=True) as interface:
        with gr.Row():
            with gr.Column(scale=3):
                input_image = gr.Image(type="pil", label="Input Image")
                t1 = gr.Slider(minimum=0, maximum=40, step=1, value=33, label="Middleground")
                t2 = gr.Slider(minimum=40, maximum=99, step=1, value=66, label="Background")
                submit_button = gr.Button("Submit", variant="primary")
            with gr.Column(scale=3):
                bokeh_output = gr.Image(type="pil", label="Bokeh Image")
            with gr.Column(scale=3):
                lens_output = gr.Image(type="pil", label="Lens Blur Image")
        submit_button.click(
            extract_object,
            inputs=[input_image, t1, t2],
            outputs=[bokeh_output, lens_output],
        )
    return interface
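# To preview the Blocks layout instead of the gr.Interface below (sketch):
#   build_interface().launch()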
title = "Gaussian Blur Background App"
description = (
    "Upload an image to apply a realistic background blur effect. "
    "The app segments the foreground with BiRefNet and applies a Gaussian "
    "blur (sigma=15) to the background, simulating a video-conferencing bokeh. "
    "It also estimates depth with Apple DepthPro and blurs the middleground "
    "and background at different strengths for a lens-blur effect."
)
iface = gr.Interface(
    fn=extract_object,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Slider(minimum=0, maximum=40, step=1, value=33, label="Middleground"),
        gr.Slider(minimum=40, maximum=99, step=1, value=66, label="Background"),
    ],
    outputs=[
        gr.Image(type="pil", label="Bokeh Image"),
        gr.Image(type="pil", label="Lens Blur Image"),
    ],
    title=title,
    description=description,
    allow_flagging="never",
)
iface.queue(default_concurrency_limit=1)
iface.launch()