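"""SST demo: label-efficient trait segmentation with SAM2.

Click one or more foreground points on the reference image, then press
"Generate": the clicks prompt SAM2 to segment the reference, and the SAM2
video predictor propagates that mask onto the target image.
"""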
import gradio as gr
import numpy as np
from PIL import Image
import sam_utils
import matplotlib.pyplot as plt
from io import BytesIO

from sam2.sam2_image_predictor import SAM2ImagePredictor

# SAM2 inference helpers
def segment_reference(image, clicks):
    """Segment the reference image with SAM2 from user-clicked prompt points.

    clicks: a list of [x, y] foreground points. With multimask_output=False
    the predictor returns a single binary mask of shape [1, H, W].
    """
    print(f"Segmenting reference at points: {clicks}")
    point_coords = np.array(clicks)
    point_labels = np.ones(len(point_coords), dtype=np.int32)  # 1 = foreground click
    sam2_img.set_image(image)

    masks, _, _ = sam2_img.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=False,
    )

    return masks

def segment_target(target_image, ref_image, ref_mask):
    """Propagate the reference mask onto the target image with the SAM2 video predictor."""
    target_image = np.array(target_image)
    ref_image = np.array(ref_image)
    # Seed the predictor state with the reference mask, propagate, and keep
    # the segmentation of the final (target) frame.
    state = sam_utils.load_masks(sam2_vid, [target_image], ref_image, ref_mask)
    out = sam_utils.propagate_masks(sam2_vid, state)[-1]["segmentation"]
    return out

def visualize_segmentation(image, masks, target_image, target_mask):
    """Render reference and target images side by side with their mask overlays."""
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    ax[0].imshow(image.convert("L"), cmap='gray')
    for i, mask in enumerate(masks):
        sam_utils.show_mask(mask, ax[0], obj_id=i, alpha=0.75)
    ax[0].axis('off')
    ax[0].set_title("Reference Image with Expert Segmentation")
    ax[1].imshow(target_image.convert("L"), cmap='gray')
    for i, mask in enumerate(target_mask):
        sam_utils.show_mask(mask, ax[1], obj_id=i, alpha=0.75)
    ax[1].axis('off')
    ax[1].set_title("Target Image with Inferred Segmentation")
    # Render the figure to a PIL image via an in-memory buffer
    plt.tight_layout()
    buf = BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    vis = Image.open(buf).copy()
    plt.close(fig)
    buf.close()
    return vis

# Store click coordinates at module level (a per-session gr.State would be
# cleaner; see the sketch after record_click below)
click_coords = []

def record_click(img, evt: gr.SelectData):
    global click_coords
    click_coords.append([evt.index[0], evt.index[1]])
    return f"Clicked at: {click_coords}"

def generate(reference_image, target_image):
    if not click_coords:
        return None, "Click on the reference image first!"
    ref_mask = segment_reference(reference_image, click_coords)
    tgt_mask = segment_target(target_image, reference_image, ref_mask)
    vis = visualize_segmentation(reference_image, ref_mask, target_image, tgt_mask)
    return vis, "Done!"

with gr.Blocks() as demo:
    gr.Markdown("### SST Demo: Label-Efficient Trait Segmentation")
    
    with gr.Row():
        reference_img = gr.Image(type="pil", label="Reference Image")
        target_img = gr.Image(type="pil", label="Target Image")
    
    click_info = gr.Textbox(label="Click Info")
    generate_btn = gr.Button("Generate")
    output_mask = gr.Image(type="pil", label="Segmentation Result")

    reference_img.select(fn=record_click, inputs=[reference_img], outputs=[click_info])
    generate_btn.click(fn=generate, inputs=[reference_img, target_img], outputs=[output_mask, click_info])

# Load the SAM2 predictors before launching: an image predictor for the
# reference clicks and a video predictor for mask propagation. Checkpoint
# and config paths are relative to the working directory.
sam2_img = sam_utils.load_SAM2(ckpt_path="checkpoints/sam2_hiera_small.pt", model_cfg_path="checkpoints/sam2_hiera_s.yaml")
sam2_img = SAM2ImagePredictor(sam2_img)
sam2_vid = sam_utils.build_sam2_predictor(checkpoint="checkpoints/sam2_hiera_small.pt", model_cfg="checkpoints/sam2_hiera_s.yaml")
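# launch() starts a local server (Gradio defaults to http://127.0.0.1:7860);
# pass share=True for a temporary public link if needed.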
demo.launch()