# Hugging Face Space (status at capture time: Running on Zero)
# K-I-S-S | |
import spaces | |
import gradio as gr | |
from gradio_image_prompter import ImagePrompter | |
from sam2.sam2_image_predictor import SAM2ImagePredictor | |
import torch | |
import numpy as np | |
from PIL import Image as PILImage | |
# SAM2 configuration (the predictor itself is created inside predict_masks).
# MODEL is the checkpoint identifier handed to SAM2ImagePredictor.from_pretrained.
MODEL = "facebook/sam2.1-hiera-large"
# Use CUDA when torch reports it as available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def predict_masks(image, points):
    """Predict a single mask from the image based on selected points.

    Args:
        image: The picture to segment; anything ``np.array`` can convert
            (presumably an RGB uint8 image coming from the Gradio
            component -- confirm against the caller).
        points: Sequence of mappings exposing ``"x"`` and ``"y"`` keys, one
            per selected point. Every point is sent to the predictor with
            label ``1``.

    Returns:
        ``None`` when no image was supplied; the unmodified image array when
        there are no points or the predictor yields no mask; otherwise an
        array holding the image blended 50/50 with an overlay whose first
        (red) channel is the predicted mask.
    """
    global PREDICTOR

    # Guard clauses: without an image or a prompt there is nothing to predict,
    # so return before touching the (expensive) model.
    if image is None:
        return None
    image_np = np.array(image)
    if points is None or len(points) == 0:
        return image_np

    # Load the checkpoint once and keep it in the module-level PREDICTOR
    # instead of re-creating the model on every button click.
    predictor = globals().get("PREDICTOR")
    if predictor is None:
        predictor = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)
        PREDICTOR = predictor

    # Hand the prompts over as explicit arrays: N x 2 coordinates, N labels.
    point_coords = np.array(
        [[point["x"], point["y"]] for point in points], dtype=np.float32
    )
    point_labels = np.ones(len(point_coords), dtype=np.int32)

    with torch.inference_mode():
        predictor.set_image(image_np)
        masks, _, _ = predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=False,
        )

    # `masks` is an array, so its truth value is ambiguous (`if masks:` raises
    # ValueError for more than one element); test for emptiness explicitly.
    if masks is None or len(masks) == 0:
        return image_np

    # Prepare the overlay image: the mask goes into the first (red) channel.
    red_mask = np.zeros_like(image_np)
    red_mask[:, :, 0] = masks[0].astype(np.uint8) * 255
    blended_image = PILImage.blend(
        PILImage.fromarray(image_np), PILImage.fromarray(red_mask), alpha=0.5
    )
    return np.array(blended_image)
def create_sam2_mask_interface():
    """Build the Gradio Blocks UI for SAM2 mask generation and return it.

    Layout: a title and description, then one row with two columns -- the
    image prompter plus a "Submit" button on the left, the segmented result
    on the right. Clicking the button runs ``predict_masks``.
    """
    description = """
            This application utilizes **Segment Anything V2 (SAM2)** to allow you to upload an image and interactively generate a segmentation mask based on multiple points you select on the image.
            """

    with gr.Blocks() as blocks_root:
        gr.Markdown("# Object Segmentation with SAM2")
        gr.Markdown(description)

        with gr.Row():
            # Left column: image + point selection and the trigger button.
            with gr.Column():
                prompter = ImagePrompter(show_label=False)
                run_button = gr.Button("Submit")
            # Right column: the blended segmentation result.
            with gr.Column():
                result_view = gr.Image(
                    label="Segmented Image", type="pil", height=400
                )

        # Wire the button to the prediction function.
        prediction_inputs = [prompter.image, prompter.points]
        run_button.click(
            fn=predict_masks,
            inputs=prediction_inputs,
            outputs=result_view,
            show_progress=True,
        )

    return blocks_root