LPX55's picture
a
f7e52d6
raw
history blame
2.41 kB
# K-I-S-S
import spaces
import gradio as gr
from gradio_image_prompter import ImagePrompter
from sam2.sam2_image_predictor import SAM2ImagePredictor
import torch
import numpy as np
from PIL import Image as PILImage
# SAM2 configuration. MODEL is the checkpoint id passed to
# SAM2ImagePredictor.from_pretrained; DEVICE selects CUDA when available and
# falls back to CPU. The predictor itself is constructed inside predict_masks.
MODEL = "facebook/sam2.1-hiera-large"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def _to_rgb_array(image):
    """Return ``image`` (PIL image or array-like) as an RGB uint8 numpy array."""
    # Normalizing to RGB keeps grayscale/RGBA inputs from breaking the
    # red-channel overlay and from feeding a non-3-channel image to SAM2.
    if isinstance(image, PILImage.Image):
        pil_image = image
    else:
        pil_image = PILImage.fromarray(np.asarray(image))
    return np.array(pil_image.convert("RGB"))


def _overlay_red_mask(image_np, mask):
    """Blend ``image_np`` 50/50 with a red layer covering the non-zero pixels of ``mask``."""
    red_layer = np.zeros_like(image_np)
    # Only the red channel is filled, so masked pixels turn red after blending.
    red_layer[:, :, 0] = (np.asarray(mask) > 0).astype(np.uint8) * 255
    blended = PILImage.blend(
        PILImage.fromarray(image_np), PILImage.fromarray(red_layer), alpha=0.5
    )
    return np.array(blended)


@spaces.GPU()
def predict_masks(image, points):
    """Predict a single mask from the image based on selected points.

    Args:
        image: Input image (PIL image or array-like); it is converted to RGB
            before being passed to the SAM2 predictor.
        points: Iterable of mappings with ``"x"`` and ``"y"`` keys. Every
            point is sent to SAM2 as a foreground prompt (label 1).

    Returns:
        An RGB uint8 numpy array: the image blended 50/50 with red over the
        predicted mask, or the unmodified RGB image when no points are given
        or the predictor returns no mask.
    """
    global PREDICTOR
    image_np = _to_rgb_array(image)
    point_coords = [[point["x"], point["y"]] for point in (points if points is not None else [])]
    if not point_coords:
        # Without any prompt there is nothing to segment; skip the model load.
        return image_np

    # NOTE(review): the weights are reloaded on every call; presumably this is
    # deliberate so loading happens inside the @spaces.GPU context — confirm
    # before caching the predictor across calls.
    PREDICTOR = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)
    point_labels = [1] * len(point_coords)
    with torch.inference_mode():
        PREDICTOR.set_image(image_np)
        masks, _, _ = PREDICTOR.predict(
            point_coords=point_coords, point_labels=point_labels, multimask_output=False
        )

    # `masks` is an array: `if masks:` raises "truth value ... is ambiguous"
    # for any non-trivial mask, so test for presence explicitly.
    if masks is None or len(masks) == 0:
        return image_np
    return _overlay_red_mask(image_np, masks[0])
def create_sam2_mask_interface():
    """Build and return the Gradio Blocks tab for SAM2 point-prompted segmentation.

    Layout: a left column holding an ``ImagePrompter`` and a "Submit" button,
    and a right column holding the output image. Clicking "Submit" calls
    ``predict_masks`` and shows its result in the output image.
    """
    intro_markdown = (
        "\n"
        "This application utilizes **Segment Anything V2 (SAM2)** to allow you to upload an image and interactively generate a segmentation mask based on multiple points you select on the image.\n"
    )

    with gr.Blocks() as tab:
        gr.Markdown("# Object Segmentation with SAM2")
        gr.Markdown(intro_markdown)

        with gr.Row():
            with gr.Column():
                prompter = ImagePrompter(show_label=False)
                run_button = gr.Button("Submit")
            with gr.Column():
                result_view = gr.Image(label="Segmented Image", type="pil", height=400)

        # Wire the Submit button to the segmentation function.
        # NOTE(review): confirm ImagePrompter actually exposes `.image` and
        # `.points`; in gradio_image_prompter the component value appears to
        # be a single {"image": ..., "points": ...} dict — verify this wiring.
        run_button.click(
            fn=predict_masks,
            inputs=[prompter.image, prompter.points],
            outputs=result_view,
            show_progress=True,
        )

    return tab