# SST / app.py
# Author: Daniel-F
# Commit: "demo v1" (ccb7bbe), 3.54 kB
import gradio as gr
import numpy as np
from PIL import Image
import sam_utils
import matplotlib.pyplot as plt
from io import BytesIO
from sam2.sam2_image_predictor import SAM2ImagePredictor
# SAM2 inference helpers used by the Gradio callbacks below.
def segment_reference(image, click):
    """Segment the reference image from positive point prompts.

    Args:
        image: Reference image handed unchanged to ``sam2_img.set_image``
            (the UI supplies a PIL image).
        click: Either a single ``[x, y]`` point or a sequence of such
            points. Every point is used as a positive prompt (label 1).

    Returns:
        The first element of the tuple returned by ``sam2_img.predict``
        when called with ``multimask_output=False``.
    """
    print(f"Segmenting reference at point: {click}")

    # Normalise to shape (N, 2) so a lone [x, y] is treated as one point
    # rather than as two one-dimensional entries.
    point_coords = np.array(click).reshape(-1, 2)
    # One positive label per prompt point.
    point_labels = np.ones(len(point_coords), dtype=int)

    sam2_img.set_image(image)
    masks, _, _ = sam2_img.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=False,
    )
    return masks
def segment_target(target_image, ref_image, ref_mask):
    """Transfer the reference mask onto the target image.

    Both images are converted to NumPy arrays, handed to
    ``sam_utils.load_masks`` together with ``ref_mask`` (the target is passed
    as a one-element list), and the resulting state is run through
    ``sam_utils.propagate_masks``.

    Returns:
        The ``'segmentation'`` entry of the last item produced by
        ``sam_utils.propagate_masks``.
    """
    target_frame = np.array(target_image)
    reference_frame = np.array(ref_image)

    tracking_state = sam_utils.load_masks(
        sam2_vid, [target_frame], reference_frame, ref_mask
    )
    propagated = sam_utils.propagate_masks(sam2_vid, tracking_state)

    # Only the final propagated entry is of interest here.
    final_entry = propagated[-1]
    return final_entry['segmentation']
def visualize_segmentation(image, masks, target_image, target_mask):
    """Render reference and target images side by side with mask overlays.

    Args:
        image: Reference image; must provide ``convert("L")``.
        masks: Iterable of masks drawn over the reference panel.
        target_image: Target image; must provide ``convert("L")``.
        target_mask: Iterable of masks drawn over the target panel.

    Returns:
        A PIL image holding the PNG rendering of the two-panel figure.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    try:
        panels = (
            (axes[0], image, masks,
             "Reference Image with Expert Segmentation"),
            (axes[1], target_image, target_mask,
             "Target Image with Inferred Segmentation"),
        )
        for axis, picture, overlays, title in panels:
            axis.imshow(picture.convert("L"), cmap='gray')
            for obj_id, mask in enumerate(overlays):
                sam_utils.show_mask(mask, axis, obj_id=obj_id, alpha=0.75)
            axis.axis('off')
            axis.set_title(title)

        # Use the figure's own methods instead of the pyplot "current figure"
        # so another figure created elsewhere cannot be rendered by mistake.
        fig.tight_layout()
        with BytesIO() as buf:
            fig.savefig(buf, format='png')
            buf.seek(0)
            # copy() forces the pixel data to be read before the buffer closes.
            return Image.open(buf).copy()
    finally:
        # Always release the figure, even if drawing or encoding failed.
        plt.close(fig)
# Module-wide list of points clicked on the reference image
# (could be replaced by per-session state).
click_coords = []


def record_click(img, evt: gr.SelectData):
    """Remember the selected point and report every point recorded so far.

    Args:
        img: Current reference image; not used by this handler.
        evt: Gradio selection event; the first two entries of ``evt.index``
            are stored as one point.

    Returns:
        A status string listing all recorded points.
    """
    # The list is only mutated in place, so no ``global`` declaration is needed.
    point = [evt.index[0], evt.index[1]]
    click_coords.append(point)
    return f"Clicked at: {click_coords}"
def generate(reference_image, target_image):
    """Run the full pipeline for the "Generate" button.

    Args:
        reference_image: Reference image from the UI, or ``None`` if none
            has been provided yet.
        target_image: Target image from the UI, or ``None`` if none has been
            provided yet.

    Returns:
        A ``(visualisation, status_message)`` tuple. The visualisation is
        ``None`` whenever a precondition is not met.
    """
    # Without both images the segmentation helpers would fail on ``None``;
    # report the problem to the user instead.
    if reference_image is None or target_image is None:
        return None, "Upload both a reference and a target image first!"
    if not click_coords:
        return None, "Click on the reference image first!"

    ref_mask = segment_reference(reference_image, click_coords)
    tgt_mask = segment_target(target_image, reference_image, ref_mask)
    vis = visualize_segmentation(
        reference_image, ref_mask, target_image, tgt_mask
    )
    return vis, "Done!"
def _reset_clicks():
    """Forget every point recorded in ``click_coords``."""
    click_coords.clear()
    return "Clicks cleared."


# UI layout: two input images, a status box, a trigger button and the result.
with gr.Blocks() as demo:
    gr.Markdown("### SST Demo: Label-Efficient Trait Segmentation")

    with gr.Row():
        reference_img = gr.Image(type="pil", label="Reference Image")
        target_img = gr.Image(type="pil", label="Target Image")

    click_info = gr.Textbox(label="Click Info")
    generate_btn = gr.Button("Generate")
    output_mask = gr.Image(type="pil", label="Generated Mask")

    # Each click on the reference image adds one prompt point.
    reference_img.select(fn=record_click, inputs=[reference_img], outputs=[click_info])
    # Points belong to the image they were clicked on: drop them whenever the
    # reference image is replaced or removed, otherwise stale points would be
    # used as prompts for the new image.
    reference_img.upload(fn=_reset_clicks, inputs=None, outputs=[click_info])
    reference_img.clear(fn=_reset_clicks, inputs=None, outputs=[click_info])
    generate_btn.click(fn=generate, inputs=[reference_img, target_img], outputs=[output_mask, click_info])
# Both predictors are built from the same checkpoint / config pair.
SAM2_CHECKPOINT = "checkpoints/sam2_hiera_small.pt"
SAM2_MODEL_CFG = "checkpoints/sam2_hiera_s.yaml"

# Module-level names are already global, so no ``global`` statement is needed.
# Image predictor used by segment_reference.
sam2_img = SAM2ImagePredictor(
    sam_utils.load_SAM2(ckpt_path=SAM2_CHECKPOINT, model_cfg_path=SAM2_MODEL_CFG)
)
# Predictor handed to sam_utils.load_masks / propagate_masks in segment_target.
sam2_vid = sam_utils.build_sam2_predictor(checkpoint=SAM2_CHECKPOINT, model_cfg=SAM2_MODEL_CFG)

demo.launch()