Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
import cv2 | |
from PIL import Image | |
from transformers import pipeline | |
import gradio as gr | |
# ===== Device Setup =====
# Select the compute target once. `device` is the torch device used for MiDaS;
# `device_index` is the integer form passed to the transformers pipeline
# (0 when CUDA is available, -1 otherwise).
if torch.cuda.is_available():
    device, device_index = torch.device("cuda"), 0
else:
    device, device_index = torch.device("cpu"), -1
# ===== MiDaS Depth Estimation Setup =====
# Fetch the DPT_Large MiDaS model from torch.hub, place it on the selected
# device and put it in evaluation mode (nn.Module.to/.eval return the module
# itself, so the chained result is the same object the hub returned).
midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large").to(device).eval()
# The same hub repo also exposes its preprocessing transforms; keep the one
# that matches the DPT model family.
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = midas_transforms.dpt_transform
# ===== Segmentation Setup =====
# SegFormer-B0 (per the model id, fine-tuned on ADE at 512x512) wrapped in a
# transformers image-segmentation pipeline. Weights are loaded in half
# precision only when running on CUDA; the CPU path stays in float32.
_seg_dtype = torch.float16 if device.type == "cuda" else torch.float32
segmenter = pipeline(
    "image-segmentation",
    model="nvidia/segformer-b0-finetuned-ade-512-512",
    device=device_index,
    torch_dtype=_seg_dtype,
)
# ===== Utility Functions =====
def resize_image(img: Image.Image, max_size: int = 512) -> Image.Image:
    """Downscale *img* so that its longer side is at most *max_size* pixels.

    The aspect ratio is preserved (up to integer truncation) and LANCZOS
    resampling is used. Images that already fit are returned unchanged -- the
    same object, not a copy.

    Args:
        img: Source PIL image.
        max_size: Upper bound for the longer side, in pixels.

    Returns:
        The resized image, or *img* itself if no resizing was needed.
    """
    width, height = img.size
    longest = max(width, height)
    if longest <= max_size:
        return img
    ratio = max_size / longest
    # Clamp each side to at least 1 px: for very elongated inputs (e.g.
    # 2000x3) the short side would otherwise truncate to 0, and PIL refuses to
    # resize to a zero-sized image.
    new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
    return img.resize(new_size, Image.LANCZOS)
# ===== Depth Prediction =====
def predict_depth(image: Image.Image) -> Image.Image:
    """Estimate a relative depth map for *image* with the MiDaS DPT_Large model.

    Args:
        image: A PIL image in any mode, or a NumPy array accepted by
            ``PIL.Image.fromarray`` (presumably H x W x 3 uint8 RGB -- confirm
            for non-Gradio callers). The input is always converted to RGB.

    Returns:
        An 8-bit single-channel PIL image with the same width and height as
        the input, min-max normalised to 0-255. A constant prediction yields
        an all-zero map.

    Raises:
        ValueError: if *image* is None.
    """
    if image is None:
        raise ValueError("predict_depth() requires an image, got None")
    # Wrap arrays as PIL first, then force every input to 3-channel RGB so
    # the MiDaS transform always receives an H x W x 3 array (RGBA / L / P
    # images would otherwise reach it with the wrong channel count).
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    img_np = np.array(image.convert("RGB"))

    # Preprocess; add a batch dimension only if the transform did not.
    input_tensor = transform(img_np).to(device)
    input_batch = input_tensor.unsqueeze(0) if input_tensor.ndim == 3 else input_tensor

    with torch.no_grad():
        prediction = midas(input_batch)
        # Resample the prediction back to the input resolution. Take the
        # single (batch, channel) slice explicitly rather than .squeeze(), so
        # a 1-pixel-high or -wide input still yields a 2-D map.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img_np.shape[:2],
            mode="bicubic",
            align_corners=False,
        )[0, 0]

    # Min-max normalise to 0-255.
    depth = prediction.cpu().numpy()
    lo, hi = depth.min(), depth.max()
    if hi > lo:
        depth = (depth - lo) / (hi - lo)
    else:
        # Constant (or non-finite) output: dividing by a zero range would give
        # NaN, which casts to meaningless uint8 values.
        depth = np.zeros_like(depth)
    return Image.fromarray((depth * 255).astype(np.uint8))
# ===== Segmentation =====
def segment_image(img: Image.Image) -> Image.Image:
    """Run the SegFormer pipeline on *img* and tint every predicted segment.

    The image is converted to RGB and downscaled with ``resize_image`` before
    segmentation, so the returned overlay has the downscaled size. Each
    segment's pixels are blended 60/40 with a randomly drawn colour (channel
    values in [50, 255)), so colours generally differ from call to call.
    """
    rgb = resize_image(img.convert('RGB'))
    segments = segmenter(rgb)
    canvas = np.array(rgb, dtype=np.uint8)
    for segment in segments:
        # Any non-zero mask pixel belongs to the segment.
        region = np.array(segment["mask"], dtype=bool)
        tint = np.random.randint(50, 255, 3, dtype=np.uint8)
        blended = canvas[region] * 0.6 + tint * 0.4
        canvas[region] = blended.astype(np.uint8)
    return Image.fromarray(canvas)
# ===== Gradio App =====
def predict_fn(input_img: Image.Image) -> Image.Image:
    """Gradio callback: estimate depth for *input_img*, then segment that depth map.

    Args:
        input_img: The uploaded PIL image; Gradio passes None when the user
            submits without uploading anything.

    Returns:
        The segmentation overlay computed on the depth map.

    Raises:
        gr.Error: if no image was supplied, so the UI shows a readable
            message rather than a generic failure.
    """
    if input_img is None:
        raise gr.Error("Please upload an image first.")
    # 1. Compute a grayscale depth map with MiDaS.
    depth_img = predict_depth(input_img)
    # 2. Segment the depth map itself (not the original photo) with SegFormer.
    return segment_image(depth_img)
# Build the I/O components up front: a PIL image in, a PIL overlay out.
_input_component = gr.Image(type="pil", label="Upload Image")
_output_component = gr.Image(type="pil", label="Segmented Depth Overlay")

iface = gr.Interface(
    fn=predict_fn,
    inputs=_input_component,
    outputs=_output_component,
    title="Depth-then-Segmentation Pipeline",
    description="Upload an image. First computes a depth map via MiDaS, then applies SegFormer segmentation on the depth map.",
)

if __name__ == "__main__":
    # Start the web server only when executed directly, not on import.
    iface.launch()