"""Gradio demo comparing depth-based and segmentation-based image blur.

Three tabs are exposed:

* "Depth Blur" - blur strength increases with estimated depth.
* "Depth Heatmap" - colour-mapped visualisation of the estimated depth.
* "Segmentation-Based Blur (BEN2)" - sharp foreground over a blurred
  background, using a BEN2 foreground mask.
"""

import functools

import gradio as gr
import matplotlib.cm as cm  # noqa: F401  (kept: existing file-level import)
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image, ImageFilter
from transformers import AutoImageProcessor, AutoModelForDepthEstimation

DEPTH_MODEL_ID = "LiheYoung/depth-anything-large-hf"
SEGMENTATION_MODEL_ID = "PramaLLC/BEN2"
# Every uploaded image is resized to this (width, height) before processing.
TARGET_SIZE = (512, 512)


# ---------------------------
# Shared helpers
# ---------------------------
def _prepare_image(uploaded_image) -> Image.Image:
    """Return the upload as an RGB PIL image resized to ``TARGET_SIZE``.

    Args:
        uploaded_image: A PIL image, or anything ``Image.open`` accepts
            (for example a file path).

    Raises:
        gr.Error: If no image was provided (``None``), so the UI shows a
            readable message instead of an internal traceback.
    """
    if uploaded_image is None:
        raise gr.Error("Please upload an image first.")
    if not isinstance(uploaded_image, Image.Image):
        uploaded_image = Image.open(uploaded_image)
    return uploaded_image.convert("RGB").resize(TARGET_SIZE)


# ---------------------------
# Depth Estimation Utilities
# ---------------------------
@functools.lru_cache(maxsize=1)
def _load_depth_model():
    """Load the depth image processor and model once and reuse them.

    Loading from pretrained weights on every request is expensive, so the
    result is cached for subsequent calls.
    """
    image_processor = AutoImageProcessor.from_pretrained(DEPTH_MODEL_ID)
    model = AutoModelForDepthEstimation.from_pretrained(DEPTH_MODEL_ID)
    return image_processor, model


def compute_depth_map(image: Image.Image, scale_factor: float) -> np.ndarray:
    """Compute a depth map for *image* with the depth-anything model.

    The model prediction is resized to the image size, min-max normalized,
    inverted (so that near=0 and far=1) and multiplied by *scale_factor*.

    Args:
        image: Input PIL image.
        scale_factor: Multiplier applied to the inverted, normalized depth.

    Returns:
        A 2-D array with the same height and width as *image*. Values lie
        in ``[0, scale_factor]`` for a positive *scale_factor*.
    """
    image_processor, model = _load_depth_model()

    # Prepare image for the model
    inputs = image_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        predicted_depth = outputs.predicted_depth

    # Interpolate predicted depth map to match image size
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],  # PIL image size is (width, height)
        mode="bicubic",
        align_corners=False,
    )

    # Min-max normalize; the epsilon avoids division by zero on flat output.
    depth_min = prediction.min()
    depth_max = prediction.max()
    depth_vis = (prediction - depth_min) / (depth_max - depth_min + 1e-8)
    depth_map = depth_vis.squeeze().cpu().numpy()

    # Invert so that near=0 and far=1, then scale
    return (1.0 - depth_map) * scale_factor


# ---------------------------
# Depth-Based Blur Functions
# ---------------------------
def layered_blur(
    image: Image.Image,
    depth_map: np.ndarray,
    num_layers: int,
    max_blur: float,
) -> Image.Image:
    """Blur *image* progressively according to *depth_map*.

    ``num_layers`` Gaussian-blurred copies are created with radii evenly
    spaced from 0 to *max_blur*. The range ``[0, 1)`` is split into
    ``num_layers`` equal bins and each pixel takes the copy matching the bin
    its depth value falls in. Pixels whose depth falls outside every bin
    (for example values >= 1 after scaling) keep the most blurred copy.

    Args:
        image: Source image.
        depth_map: 2-D array with the same height/width as *image*.
        num_layers: Number of blur levels; must be at least 1.
        max_blur: Gaussian radius used for the most blurred layer.

    Returns:
        The composited image.

    Raises:
        ValueError: If *num_layers* is smaller than 1.
    """
    if num_layers < 1:
        raise ValueError("num_layers must be at least 1")

    blur_radii = np.linspace(0, max_blur, num_layers)
    blur_versions = [
        image.filter(ImageFilter.GaussianBlur(radius)) for radius in blur_radii
    ]

    # Use a fixed range (0 to 1) since the depth map is normalized
    thresholds = np.linspace(0, 1, num_layers + 1)

    # Start from the most blurred copy. The last bin would composite that
    # same copy onto itself, so only the first num_layers - 1 bins are
    # applied. The bins are disjoint, so the order does not matter.
    final_image = blur_versions[-1]
    for i in range(num_layers - 1):
        in_bin = (depth_map >= thresholds[i]) & (depth_map < thresholds[i + 1])
        # A 2-D uint8 array is interpreted as an "L" (8-bit greyscale) image.
        mask_image = Image.fromarray(in_bin.astype(np.uint8) * 255)
        final_image = Image.composite(blur_versions[i], final_image, mask_image)

    return final_image


def process_depth_blur(uploaded_image, max_blur_value, scale_factor, num_layers):
    """Resize the upload to 512x512 and apply depth-based layered blur.

    Args:
        uploaded_image: PIL image or something ``Image.open`` accepts.
        max_blur_value: Gaussian radius of the most blurred layer.
        scale_factor: Multiplier applied to the depth map.
        num_layers: Number of blur layers (converted with ``int``).

    Returns:
        The blurred PIL image.
    """
    image = _prepare_image(uploaded_image)
    depth_map = compute_depth_map(image, scale_factor)
    return layered_blur(image, depth_map, int(num_layers), max_blur_value)


# ---------------------------
# Depth Heatmap Functions
# ---------------------------
def create_heatmap(depth_map: np.ndarray, intensity: float) -> Image.Image:
    """Render *depth_map* with the "inferno" colormap.

    The depth values are multiplied by *intensity* and clipped to [0, 1]
    before the colormap is applied.

    Args:
        depth_map: 2-D array of depth values.
        intensity: Multiplier applied before clipping.

    Returns:
        An RGB PIL image.
    """
    # Multiply depth map by intensity and clip to 0-1 range
    normalized = np.clip(depth_map * intensity, 0, 1)
    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9.
    colormap = plt.get_cmap("inferno")
    colored = colormap(normalized)  # RGBA floats in [0, 1]
    # Drop alpha and convert to [0, 255]
    heatmap = (colored[:, :, :3] * 255).astype(np.uint8)
    return Image.fromarray(heatmap)


def process_depth_heatmap(uploaded_image, intensity):
    """Resize the upload to 512x512 and return its depth heatmap.

    The depth map is computed with a scale factor of 1.0.

    Args:
        uploaded_image: PIL image or something ``Image.open`` accepts.
        intensity: Multiplier passed to :func:`create_heatmap`.

    Returns:
        The heatmap as a PIL image.
    """
    image = _prepare_image(uploaded_image)
    depth_map = compute_depth_map(image, scale_factor=1.0)
    return create_heatmap(depth_map, intensity)


# --- Segmentation-Based Blur using BEN2 ---
@functools.lru_cache(maxsize=1)
def load_segmentation_model():
    """Load and cache the BEN2 segmentation model.

    Requires the ``ben2`` package to be installed; it is imported lazily so
    the rest of the app can be imported without it.

    Returns:
        A ``(model, device)`` tuple. The model is moved to CUDA when
        available (otherwise CPU) and put in eval mode.
    """
    from ben2 import BEN_Base  # Import BEN2

    seg_model = BEN_Base.from_pretrained(SEGMENTATION_MODEL_ID)
    seg_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seg_model.to(seg_device).eval()
    return seg_model, seg_device


def process_segmentation_blur(uploaded_image, seg_blur_radius: float):
    """Blur the background of the upload while keeping the foreground sharp.

    The image is resized to 512x512 and a Gaussian-blurred copy is created.
    The alpha channel of the BEN2 foreground output is thresholded at 128
    and used as a mask to composite the sharp image over the blurred one.

    Args:
        uploaded_image: PIL image or something ``Image.open`` accepts.
        seg_blur_radius: Gaussian blur radius for the background.

    Returns:
        The composited PIL image.
    """
    image = _prepare_image(uploaded_image)
    seg_model, _ = load_segmentation_model()

    blurred_image = image.filter(ImageFilter.GaussianBlur(seg_blur_radius))

    # Generate segmentation mask (foreground)
    foreground = seg_model.inference(image, refine_foreground=False)
    alpha = foreground.convert("RGBA").getchannel("A")
    binary_mask = alpha.point(lambda x: 255 if x > 128 else 0, mode="L")

    return Image.composite(image, blurred_image, binary_mask)


# --- Merged Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# Depth-Based vs Segmentation-Based Blur")
    with gr.Tabs():
        with gr.Tab("Depth Blur"):
            img_input = gr.Image(type="pil", label="Upload Image")
            blur_slider = gr.Slider(1, 50, value=20, label="Maximum Blur Radius")
            scale_slider = gr.Slider(0.1, 2.0, value=1.0, label="Depth Scale Factor")
            # step=1 keeps the slider on whole numbers; the handler
            # truncates the value with int() anyway.
            layers_slider = gr.Slider(
                2, 10, value=5, step=1, label="Number of Layers"
            )
            blur_output = gr.Image(label="Depth Blur Result")
            blur_button = gr.Button("Process Depth Blur")
            blur_button.click(
                process_depth_blur,
                inputs=[img_input, blur_slider, scale_slider, layers_slider],
                outputs=blur_output,
            )
        with gr.Tab("Depth Heatmap"):
            img_input2 = gr.Image(type="pil", label="Upload Image")
            intensity_slider = gr.Slider(
                0.5, 5.0, value=1.0, label="Heatmap Intensity"
            )
            heatmap_output = gr.Image(label="Depth Heatmap")
            heatmap_button = gr.Button("Generate Depth Heatmap")
            heatmap_button.click(
                process_depth_heatmap,
                inputs=[img_input2, intensity_slider],
                outputs=heatmap_output,
            )
        with gr.Tab("Segmentation-Based Blur (BEN2)"):
            seg_img = gr.Image(type="pil", label="Upload Image")
            seg_blur = gr.Slider(
                5, 30, value=15, step=1, label="Segmentation Blur Radius"
            )
            seg_out = gr.Image(label="Segmentation-Based Blurred Image")
            seg_button = gr.Button("Process Segmentation Blur")
            seg_button.click(
                process_segmentation_blur,
                inputs=[seg_img, seg_blur],
                outputs=seg_out,
            )

if __name__ == "__main__":
    # Optionally, set share=True to generate a public link.
    demo.launch(share=True)