import gradio as gr
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
from PIL import Image, ImageFilter
import numpy as np
import torch
from scipy.ndimage import gaussian_filter

# Load the OneFormer processor and model globally (to avoid reloading for each request)
processor = None
model = None
try:
    processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_coco_swin_large")
    model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_coco_swin_large")
except Exception as e:
    print(f"Error loading OneFormer model: {e}")

def apply_gaussian_blur(image, mask, radius):
    """Applies Gaussian blur to the background of the image, keeping the masked foreground sharp."""
    blurred_background = image.filter(ImageFilter.GaussianBlur(radius=radius))
    img_array = np.array(image)
    blurred_array = np.array(blurred_background)
    # Pixels where the mask is non-zero belong to the foreground and stay sharp
    foreground_mask = mask > 0
    foreground_mask_3d = np.stack([foreground_mask] * 3, axis=-1)
    final_image_array = np.where(foreground_mask_3d, img_array, blurred_array)
    return Image.fromarray(final_image_array.astype(np.uint8))

def apply_lens_blur(image, mask, strength):
    """Placeholder lens blur: approximates the effect by compositing a Gaussian-blurred background through the mask."""
    img_array = np.array(image).astype(np.float32)
    mask_array = (np.array(mask) / 255.0)[:, :, np.newaxis]  # 1.0 = foreground, 0.0 = background
    # Blur along the spatial axes only; leave the colour channels untouched
    blurred_image = gaussian_filter(img_array, sigma=(strength, strength, 0))
    # Keep the foreground sharp and take the blurred pixels for the background
    final_image = mask_array * img_array + (1.0 - mask_array) * blurred_image
    return Image.fromarray(np.clip(final_image, 0, 255).astype(np.uint8))

def segment_and_blur(input_image, blur_type, gaussian_radius=15, lens_strength=5):
    """Segments the input image and applies the selected blur to the background."""
    if processor is None or model is None:
        raise gr.Error("OneFormer model not loaded.")
    image = input_image.convert("RGB")
    # Rotate the image (assuming this is still needed)
    image = image.rotate(-90, expand=True)
    # Prepare input for semantic segmentation
    inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt")
    # Run semantic segmentation without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
    # Post-process into a per-pixel class-ID map at the original image resolution
    predicted_semantic_map = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    segmentation_mask = predicted_semantic_map.cpu().numpy()
    # Get the mapping of class IDs to labels
    id2label = model.config.id2label
    # Use the 'person' class as the foreground to keep in focus
    foreground_label = 'person'
    foreground_class_id = None
    for class_id, label in id2label.items():
        if label == foreground_label:
            foreground_class_id = class_id
            break
    if foreground_class_id is None:
        raise gr.Error(f"Could not find the label '{foreground_label}' in the model's class mapping.")
    # Black background mask
    output_mask_array = np.zeros(segmentation_mask.shape, dtype=np.uint8)
    # Set the pixels corresponding to the foreground object to white (255)
    output_mask_array[segmentation_mask == foreground_class_id] = 255
    # Convert the NumPy array to a PIL Image and resize to match the input image
    mask_pil = Image.fromarray(output_mask_array, mode='L').resize(image.size)
    mask_array = np.array(mask_pil)
    if blur_type == "Gaussian":
        blurred_image = apply_gaussian_blur(image, mask_array, gaussian_radius)
    elif blur_type == "Lens":
        blurred_image = apply_lens_blur(image, mask_array, lens_strength)
    else:
        raise gr.Error("Invalid blur type selected.")
    return blurred_image

iface = gr.Interface(
    fn=segment_and_blur,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Radio(["Gaussian", "Lens"], label="Blur Type", value="Gaussian"),
        gr.Slider(0, 30, step=1, value=15, label="Gaussian Blur Radius"),
        gr.Slider(0, 10, step=1, value=5, label="Lens Blur Strength"),
    ],
    outputs=gr.Image(label="Output Image"),
    title="Image Background Blur App",
    description="Upload an image, select a blur type (Gaussian or Lens), and adjust the blur parameters to blur the background while keeping the person in focus.",
)
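
# Note: both sliders are always shown in the UI; only the one matching the selected blur type is applied.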

if __name__ == "__main__":
    iface.launch()