# Spaces: Running
import gradio as gr | |
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation | |
from PIL import Image, ImageFilter | |
import numpy as np | |
import torch | |
from scipy.ndimage import gaussian_filter | |
# Load the OneFormer processor and model once at import time so that every
# request reuses them instead of reloading. If loading fails, the error is
# printed and `model` is left as None; segment_and_blur checks for that
# before running inference.
processor = model = None
try:
    processor = OneFormerProcessor.from_pretrained(
        "shi-labs/oneformer_coco_swin_large"
    )
    model = OneFormerForUniversalSegmentation.from_pretrained(
        "shi-labs/oneformer_coco_swin_large"
    )
except Exception as load_error:
    print(f"Error loading OneFormer model: {load_error}")
def apply_gaussian_blur(image, mask, radius):
    """Blur the background of ``image`` while keeping the masked area sharp.

    Args:
        image: PIL image to process.
        mask: 2-D array with the same height and width as ``image``. Pixels
            greater than 0 are treated as foreground and left untouched.
        radius: Radius passed to ``ImageFilter.GaussianBlur``.

    Returns:
        A new PIL image made of the original pixels where the mask is set
        and the Gaussian-blurred pixels everywhere else.

    Raises:
        ValueError: If ``mask`` does not match the image height and width.
    """
    blurred_background = image.filter(ImageFilter.GaussianBlur(radius=radius))
    sharp_pixels = np.array(image)
    blurred_pixels = np.array(blurred_background)

    keep_sharp = np.asarray(mask) > 0
    if keep_sharp.shape != sharp_pixels.shape[:2]:
        raise ValueError(
            f"Mask shape {keep_sharp.shape} does not match image shape "
            f"{sharp_pixels.shape[:2]}."
        )
    # Add a trailing axis so the mask broadcasts over however many channels
    # the image has (RGB, RGBA, ...); single-band images stay 2-D.
    if sharp_pixels.ndim == 3:
        keep_sharp = keep_sharp[..., np.newaxis]

    composite = np.where(keep_sharp, sharp_pixels, blurred_pixels)
    return Image.fromarray(composite.astype(np.uint8))
def apply_lens_blur(image, mask, strength):
    """Blur the background of ``image`` with a strength-controlled soft blur.

    This is still a simple stand-in for a true lens (bokeh) blur: the image
    is smoothed with a Gaussian whose sigma is ``strength`` and then blended
    with the original so that the masked foreground stays sharp.

    Args:
        image: PIL image to process.
        mask: 2-D array with the same height and width as ``image``, with
            values in 0-255 where 255 marks foreground to keep in focus.
        strength: Gaussian sigma, in pixels, applied along the height and
            width axes. 0 leaves the image unchanged.

    Returns:
        A new PIL image with the background blurred.
    """
    img_array = np.array(image).astype(np.float32)
    # Normalize the mask to 0-1 so it can be used as a blend weight.
    foreground_weight = np.clip(
        np.asarray(mask, dtype=np.float32) / 255.0, 0.0, 1.0
    )

    # gaussian_filter expects one scalar sigma per axis (not a per-pixel
    # array). Use sigma 0 on any channel axis so colours are not mixed.
    sigma = (strength, strength) + (0,) * (img_array.ndim - 2)
    blurred = gaussian_filter(img_array, sigma=sigma)

    if img_array.ndim == 3:
        foreground_weight = foreground_weight[..., np.newaxis]
    # Foreground (weight 1) keeps the original pixels; background (weight 0)
    # takes the blurred ones.
    blended = foreground_weight * img_array + (1.0 - foreground_weight) * blurred
    return Image.fromarray(np.clip(np.rint(blended), 0, 255).astype(np.uint8))
def segment_and_blur(input_image, blur_type, gaussian_radius=15, lens_strength=5):
    """Segment the person in the input image and blur everything else.

    Args:
        input_image: PIL image, or a NumPy array (the format Gradio's Image
            component delivers by default).
        blur_type: Either "Gaussian" or "Lens".
        gaussian_radius: Blur radius used when ``blur_type`` is "Gaussian".
        lens_strength: Blur strength used when ``blur_type`` is "Lens".

    Returns:
        The blurred PIL image produced by the selected blur function.

    Raises:
        gr.Error: If the model is not loaded, no image was provided, the
            blur type is unknown, or the model has no 'person' class. An
            error string returned to an Image output cannot be displayed, so
            failures are raised for Gradio to show instead.
    """
    if processor is None or model is None:
        raise gr.Error("Error: OneFormer model not loaded.")
    if input_image is None:
        raise gr.Error("Error: No input image provided.")
    # Validate the cheap argument before running the expensive model.
    if blur_type not in ("Gaussian", "Lens"):
        raise gr.Error("Error: Invalid blur type selected.")

    if isinstance(input_image, np.ndarray):
        input_image = Image.fromarray(input_image)
    image = input_image.convert("RGB")

    # NOTE(review): every input is rotated 90 degrees clockwise here, so the
    # returned image is rotated as well. Kept from the original code, which
    # was itself unsure the rotation is still needed - confirm.
    image = image.rotate(-90, expand=True)

    # Run semantic segmentation.
    inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports size as (width, height); it is reversed here to request a
    # (height, width) map, so the result already matches the image size.
    predicted_semantic_map = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    segmentation_mask = predicted_semantic_map.cpu().numpy()

    # Look up the class id of the foreground label in the model's mapping.
    foreground_label = "person"
    foreground_class_id = next(
        (
            class_id
            for class_id, label in model.config.id2label.items()
            if label == foreground_label
        ),
        None,
    )
    if foreground_class_id is None:
        raise gr.Error(
            f"Error: Could not find the label '{foreground_label}' in the model's class mapping."
        )

    # White (255) where the foreground class was predicted, black elsewhere.
    mask_array = np.where(
        segmentation_mask == foreground_class_id, 255, 0
    ).astype(np.uint8)

    if blur_type == "Gaussian":
        return apply_gaussian_blur(image, mask_array, gaussian_radius)
    return apply_lens_blur(image, mask_array, lens_strength)
# Gradio UI: an image upload, a blur-type selector and one slider per blur
# parameter, wired to segment_and_blur.
iface = gr.Interface(
    fn=segment_and_blur,
    inputs=[
        # segment_and_blur calls PIL methods on its input, so request a PIL
        # image instead of Gradio's default NumPy array.
        gr.Image(type="pil", label="Input Image"),
        gr.Radio(["Gaussian", "Lens"], label="Blur Type", value="Gaussian"),
        # The initial slider position is set with `value`, matching the
        # Radio above; `default` is not the keyword used by this API.
        gr.Slider(0, 30, step=1, value=15, label="Gaussian Blur Radius"),
        gr.Slider(0, 10, step=1, value=5, label="Lens Blur Strength"),
    ],
    outputs=gr.Image(label="Output Image"),
    title="Image Background Blur App",
    description="Upload an image, select a blur type (Gaussian or Lens), and adjust the blur parameters to blur the background while keeping the person in focus.",
)

if __name__ == "__main__":
    iface.launch()