|
import gradio as gr |
|
import numpy as np |
|
from PIL import Image |
|
import sam_utils |
|
import matplotlib.pyplot as plt |
|
from io import BytesIO |
|
|
|
from sam2.sam2_image_predictor import SAM2ImagePredictor |
|
|
|
|
|
def segment_reference(image, click):
    """Segment the reference image with SAM2 from user-supplied clicks.

    Args:
        image: PIL image handed to the SAM2 image predictor.
        click: sequence of ``[x, y]`` pixel coordinates, one per click;
            every click is treated as a positive (foreground) prompt.

    Returns:
        Mask array from ``SAM2ImagePredictor.predict`` with
        ``multimask_output=False``.
    """
    print(f"Segmenting reference at point: {click}")
    point_coords = np.asarray(click)
    # Label 1 == foreground prompt for every click (vectorized np.ones
    # replaces the original list-comprehension round trip).
    point_labels = np.ones(len(point_coords), dtype=np.int32)

    # sam2_img is the module-level SAM2ImagePredictor built at startup.
    sam2_img.set_image(image)
    masks, _, _ = sam2_img.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=False,
    )
    return masks
|
|
|
def segment_target(target_image, ref_image, ref_mask):
    """Propagate the reference mask onto the target image.

    Treats (reference, target) as a two-frame video and uses the SAM2
    video predictor to carry the reference mask forward.

    Returns:
        The propagated segmentation for the target frame.
    """
    tgt_arr = np.array(target_image)
    ref_arr = np.array(ref_image)
    state = sam_utils.load_masks(sam2_vid, [tgt_arr], ref_arr, ref_mask)
    propagated = sam_utils.propagate_masks(sam2_vid, state)
    # The final entry corresponds to the (single) target frame.
    return propagated[-1]['segmentation']
|
|
|
def visualize_segmentation(image, masks, target_image, target_mask):
    """Render the reference and target images side by side with mask overlays.

    Returns:
        A PIL image of the rendered matplotlib figure.
    """
    fig, (ref_ax, tgt_ax) = plt.subplots(1, 2, figsize=(12, 6))

    # Both panels follow the same recipe: grayscale image, overlaid masks,
    # hidden axes, panel title.
    panels = (
        (ref_ax, image, masks, "Reference Image with Expert Segmentation"),
        (tgt_ax, target_image, target_mask, "Target Image with Inferred Segmentation"),
    )
    for ax, img, mask_set, title in panels:
        ax.imshow(img.convert("L"), cmap='gray')
        for obj_id, mask in enumerate(mask_set):
            sam_utils.show_mask(mask, ax, obj_id=obj_id, alpha=0.75)
        ax.axis('off')
        ax.set_title(title)

    plt.tight_layout()
    buf = BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    # Copy so the image stays valid after the buffer is closed.
    vis = Image.open(buf).copy()
    plt.close(fig)
    buf.close()
    return vis
|
|
|
|
|
# Module-level accumulator of [x, y] click positions on the reference image;
# appended to by record_click() and consumed by generate().
click_coords = []
|
|
|
def record_click(img, evt: gr.SelectData):
    """Store the pixel coordinates of a click on the reference image.

    ``click_coords`` is mutated in place (append), so no ``global``
    declaration is required.

    Returns:
        A status string echoing every click recorded so far.
    """
    x, y = evt.index[0], evt.index[1]
    click_coords.append([x, y])
    return f"Clicked at: {click_coords}"
|
|
|
def generate(reference_image, target_image):
    """Run the full pipeline: reference segmentation, propagation, rendering.

    Returns:
        ``(visualization, status)`` — the rendered figure (or ``None``)
        and a status message for the UI.
    """
    # Guard: nothing to segment until the user has clicked at least once.
    if not click_coords:
        return None, "Click on the reference image first!"

    reference_mask = segment_reference(reference_image, click_coords)
    target_mask = segment_target(target_image, reference_image, reference_mask)
    visualization = visualize_segmentation(
        reference_image, reference_mask, target_image, target_mask
    )
    return visualization, "Done!"
|
|
|
# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("### SST Demo: Label-Efficient Trait Segmentation")

    with gr.Row():
        reference_img = gr.Image(type="pil", label="Reference Image")
        target_img = gr.Image(type="pil", label="Target Image")

    click_info = gr.Textbox(label="Click Info")
    generate_btn = gr.Button("Generate")
    output_mask = gr.Image(type="pil", label="Generated Mask")

    # Wire UI events: clicks on the reference image record prompt points;
    # the button runs the full segmentation pipeline.
    reference_img.select(fn=record_click, inputs=[reference_img], outputs=[click_info])
    generate_btn.click(fn=generate, inputs=[reference_img, target_img], outputs=[output_mask, click_info])

# --- Model setup -------------------------------------------------------------
# Load the SAM2 checkpoints once at startup. Plain module-level assignment is
# all the callbacks above need; `global` is a no-op at module scope, so the
# original `global sam2_img` / `global sam2_vid` statements were removed.
sam2_img = sam_utils.load_SAM2(ckpt_path="checkpoints/sam2_hiera_small.pt", model_cfg_path="checkpoints/sam2_hiera_s.yaml")
sam2_img = SAM2ImagePredictor(sam2_img)
sam2_vid = sam_utils.build_sam2_predictor(checkpoint="checkpoints/sam2_hiera_small.pt", model_cfg="checkpoints/sam2_hiera_s.yaml")

demo.launch()
|
|