"""Streamlit demo: blur an image's background via segmentation or depth."""

import io
import os

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from PIL import Image, ImageFilter
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from transformers import pipeline

SEGMENTATION_MODEL_ID = "briaai/RMBG-2.0"
DEPTH_MODEL_ID = "depth-anything/Depth-Anything-V2-Small-hf"
DEVICE = "cpu"
# Probability above which a segmentation pixel counts as foreground.
MASK_THRESHOLD = 0.5

# Preprocessing for the segmentation model: fixed 512x512 input, ImageNet mean/std.
_SEGMENTATION_TRANSFORM = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def depth_based_blur(orig_image: Image.Image, depth_map: Image.Image,
                     max_blur: float = 15, num_bands: int = 10,
                     invert_depth: bool = True) -> Image.Image:
    """Blur an image band-by-band according to a depth map.

    The depth map is normalized to [0, 1] (optionally inverted) and split into
    ``num_bands`` equal-width bands. Every pixel of a band is taken from a
    Gaussian-blurred copy of the image with radius
    ``(1 - band_midpoint) * max_blur``, so pixels with a small normalized depth
    receive the strongest blur.

    Args:
        orig_image: Image to blur.
        depth_map: Depth image. A multi-channel map is converted to grayscale
            and a map whose size differs from ``orig_image`` is resized to it.
        max_blur: Blur radius approached as the normalized depth nears 0.
        num_bands: Number of depth bands; more bands give smoother transitions
            at the cost of one Gaussian blur per non-empty band.
        invert_depth: If True, flip the normalized depth so that large raw
            depth values are blurred most.

    Returns:
        The blurred image in RGB mode.
    """
    # Image.composite requires the mask to match the image size, and the band
    # masks below must be 2-D, so normalize the depth map's geometry first.
    if len(depth_map.getbands()) != 1:
        depth_map = depth_map.convert("L")
    if depth_map.size != orig_image.size:
        depth_map = depth_map.resize(orig_image.size, resample=Image.BILINEAR)

    depth = np.asarray(depth_map, dtype=np.float32)
    d_min, d_max = depth.min(), depth.max()
    depth_norm = (depth - d_min) / (d_max - d_min + 1e-8)
    if invert_depth:
        depth_norm = 1.0 - depth_norm

    orig_rgba = orig_image.convert("RGBA")
    final_image = orig_rgba.copy()

    band_edges = np.linspace(0.0, 1.0, num_bands + 1)
    # Band i covers edges[i] <= d < edges[i+1]. The last band is closed at
    # 1.0 so the extreme-depth pixels (which normalize to exactly 1.0 in
    # float32) are not left out of every band.
    band_index = np.digitize(depth_norm, band_edges[1:-1])

    for i in range(num_bands):
        in_band = band_index == i
        if not in_band.any():
            continue  # nothing to composite; skip the full-image blur
        mid = (band_edges[i] + band_edges[i + 1]) / 2.0
        blur_radius = (1.0 - mid) * max_blur
        blurred = orig_rgba.filter(ImageFilter.GaussianBlur(blur_radius))
        band_mask = Image.fromarray(in_band.astype(np.uint8) * 255)
        final_image = Image.composite(blurred, final_image, band_mask)

    return final_image.convert("RGB")


@st.cache_resource(show_spinner="Loading background-removal model...")
def _load_segmentation_model(hf_token: str):
    """Load the gated RMBG-2.0 segmentation model once per server process."""
    torch.set_float32_matmul_precision("high")
    model = AutoModelForImageSegmentation.from_pretrained(
        SEGMENTATION_MODEL_ID, trust_remote_code=True, token=hf_token
    )
    model.to(DEVICE)
    model.eval()
    return model


@st.cache_resource(show_spinner="Loading depth-estimation model...")
def _load_depth_pipeline():
    """Create the Depth Anything V2 depth-estimation pipeline once per process."""
    return pipeline("depth-estimation", model=DEPTH_MODEL_ID)


def _foreground_mask(model, image: Image.Image) -> Image.Image:
    """Return an "L" mask sized like ``image``: white where the model predicts foreground."""
    input_tensor = _SEGMENTATION_TRANSFORM(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        # Take the model's final output and squash it to [0, 1] with sigmoid.
        preds = model(input_tensor)[-1].sigmoid().cpu()
    binary_mask = (preds[0].squeeze() > MASK_THRESHOLD).float()
    # ToPILImage scales a 0/1 float tensor to a 0/255 "L" image.
    mask = transforms.ToPILImage()(binary_mask).convert("L")
    # Bilinear upscaling softens the mask edges back at full resolution.
    return mask.resize(image.size, resample=Image.BILINEAR)


@st.cache_data(show_spinner="Analysing image...")
def _analyse_image(image_bytes: bytes, _hf_token: str):
    """Compute (foreground mask, depth map) for an uploaded image.

    Cached on the raw upload bytes so slider changes do not rerun inference.
    ``_hf_token`` is excluded from the cache key (leading underscore).
    """
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    mask = _foreground_mask(_load_segmentation_model(_hf_token), image)
    depth_map = _load_depth_pipeline()(image)["depth"]
    return mask, depth_map


def main():
    """Run the Streamlit app: upload an image, show two background-blur variants."""
    hf_token = os.environ.get("HF_ACCESS_TOKEN")
    if hf_token is None:
        raise RuntimeError("HF_ACCESS_TOKEN is not set. Please add it as a secret.")

    st.title("Custom Background Blur Demo")

    # 1. Upload an image
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # 2. Open and display the original image
    image_bytes = uploaded_file.getvalue()
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    st.image(image, caption="Original Image", use_container_width=True)

    st.write("---")
    st.subheader("Blur Settings")
    col1, col2 = st.columns(2)

    mask, depth_map = _analyse_image(image_bytes, hf_token)

    with col1:
        gauss_radius = st.slider("Gaussian Blur Radius", 0, 30, 10, key="gauss")
        blurred_image = image.filter(ImageFilter.GaussianBlur(gauss_radius))
        # Image.composite takes the sharp image where the mask is white
        # (foreground) and the blurred image where it is black (background).
        final_image = Image.composite(image, blurred_image, mask)
        st.image(
            final_image,
            caption=f"Gaussian Blur (radius={gauss_radius})",
            use_container_width=True,
        )

    with col2:
        blur_max = st.slider("Lens Blur Radius", 0, 5, 1, key="lens")
        # invert_depth=False: presumably the depth pipeline emits larger values
        # for nearer pixels, so small values (far) get the most blur — confirm.
        output_image = depth_based_blur(
            image, depth_map, max_blur=blur_max, num_bands=40, invert_depth=False
        )
        st.image(
            output_image,
            caption=f"Lens Blur (blur={blur_max})",
            use_container_width=True,
        )


if __name__ == "__main__":
    main()