File size: 2,411 Bytes
7a13abb
112487d
a88efc1
7a13abb
 
a88efc1
 
7a13abb
 
 
a745151
7a13abb
 
112487d
7a13abb
 
112487d
 
7a13abb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a88efc1
7a13abb
 
 
 
 
 
 
 
 
 
 
a88efc1
8fa6d2e
7a13abb
8fa6d2e
 
f7e52d6
7a13abb
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# K-I-S-S
import spaces
import gradio as gr
from gradio_image_prompter import ImagePrompter
from sam2.sam2_image_predictor import SAM2ImagePredictor
import torch
import numpy as np
from PIL import Image as PILImage

# Pretrained SAM2 checkpoint identifier (passed to
# SAM2ImagePredictor.from_pretrained) and the torch device used to run it:
# CUDA when torch reports it as available, otherwise the CPU.
MODEL = "facebook/sam2.1-hiera-large"
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def _to_rgb_array(image):
    """Return *image* (PIL image, numpy array, or file path) as an RGB uint8 array."""
    if image is None:
        raise ValueError("No image provided.")
    if isinstance(image, str):
        # Close the file handle as soon as the pixels are copied out.
        with PILImage.open(image) as opened:
            return np.array(opened.convert("RGB"))
    if isinstance(image, np.ndarray):
        image = PILImage.fromarray(image)
    # SAM2 and the red-channel overlay below both need exactly 3 channels;
    # RGBA / grayscale inputs would otherwise break indexing or preprocessing.
    return np.array(image.convert("RGB"))


def _point_xy(point):
    """Extract ``[x, y]`` from a ``{"x": .., "y": ..}`` mapping or an ``[x, y, ...]`` sequence."""
    if isinstance(point, dict):
        return [point["x"], point["y"]]
    # NOTE(review): the sequence form presumably comes from gradio_image_prompter,
    # whose extra items encode the prompt type; only x and y are used here.
    return [point[0], point[1]]


@spaces.GPU()
def predict_masks(image, points):
    """Predict a single mask from the image based on selected points.

    Every point is used as a foreground (label 1) prompt, and the single
    predicted mask is blended 50/50 onto the image as a red overlay.

    Args:
        image: The input image as a PIL image, a numpy array, or a path to an
            image file. It is converted to 3-channel RGB before prediction.
        points: Iterable of points. Each point is either a mapping with
            ``"x"`` and ``"y"`` keys or a sequence whose first two items are
            the x and y pixel coordinates. ``None`` or empty means no prompt.

    Returns:
        A uint8 RGB numpy array: the blended overlay, or the RGB input
        unchanged when there are no points or no mask was returned.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    global PREDICTOR

    image_np = _to_rgb_array(image)
    points_list = [_point_xy(point) for point in (points or [])]
    if not points_list:
        # Nothing to prompt SAM2 with; skip loading the model entirely.
        return image_np
    input_labels = [1] * len(points_list)

    # NOTE(review): weights are reloaded on every call; consider caching once it
    # is confirmed safe under the spaces.GPU execution model.
    PREDICTOR = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)

    with torch.inference_mode():
        PREDICTOR.set_image(image_np)
        masks, _, _ = PREDICTOR.predict(
            point_coords=np.asarray(points_list, dtype=np.float32),
            point_labels=np.asarray(input_labels, dtype=np.int32),
            multimask_output=False,
        )

    # `masks` is an ndarray: `if masks:` raises "truth value of an array is
    # ambiguous" for any multi-pixel mask, so check presence explicitly.
    if masks is None or len(masks) == 0:
        return image_np

    # Paint the mask into the red channel of an otherwise black image.
    red_mask = np.zeros_like(image_np)
    red_mask[:, :, 0] = (np.asarray(masks[0]) > 0).astype(np.uint8) * 255
    blended_image = PILImage.blend(
        PILImage.fromarray(image_np), PILImage.fromarray(red_mask), alpha=0.5
    )
    return np.array(blended_image)

def create_sam2_mask_interface():
    """Create the Gradio interface for SAM2 mask generation.

    Returns:
        The ``gr.Blocks`` container holding the upload/prompt widget, the
        submit button, and the segmented-image output.
    """
    with gr.Blocks() as sam2_mask_tab:
        gr.Markdown("# Object Segmentation with SAM2")
        gr.Markdown(
            """
            This application utilizes **Segment Anything V2 (SAM2)** to allow you to upload an image and interactively generate a segmentation mask based on multiple points you select on the image.
            """
        )
        with gr.Row():
            with gr.Column():
                upload_image_input = ImagePrompter(show_label=False)
                submit_button = gr.Button("Submit")
            with gr.Column():
                image_output = gr.Image(label="Segmented Image", type="pil", height=400)

        def _segment(prompts):
            """Unpack the ImagePrompter value and forward it to ``predict_masks``."""
            # The ImagePrompter component delivers one value dict with "image"
            # and "points" entries; it has no `.image` / `.points` sub-components
            # that could be wired as separate event inputs.
            prompts = prompts or {}
            image = prompts.get("image")
            if image is None:
                raise gr.Error("Please upload an image first.")
            # Normalise each point to the {"x", "y"} form predict_masks reads.
            # NOTE(review): sequence points are assumed to start with x, y.
            points = [
                point if isinstance(point, dict) else {"x": point[0], "y": point[1]}
                for point in (prompts.get("points") or [])
            ]
            return predict_masks(image, points)

        # Define the action triggered by the submit button
        submit_button.click(
            fn=_segment,
            inputs=upload_image_input,
            outputs=image_output,
            show_progress=True,
        )

    return sam2_mask_tab