SST / app.py
Daniel-F's picture
support for multi-image
7de04d2
import subprocess
import os
import sys
import subprocess
def run(cmd, cwd=None):
    """Echo and execute a command, raising if it exits with a non-zero status.

    Args:
        cmd: either a shell command string (executed through the shell, as
            before) or a sequence of arguments (executed directly, no shell).
        cwd: working directory for the command; ``None`` keeps the current one.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    import shlex

    use_shell = isinstance(cmd, str)
    shown = cmd if use_shell else shlex.join(map(str, cmd))
    # flush=True so the echo is written before the child starts: the child
    # writes to the inherited file descriptors directly, bypassing Python's
    # stdout buffer, so an unflushed echo could show up after its output.
    print(f"▶ {shown}", flush=True)
    subprocess.check_call(cmd, shell=use_shell, cwd=cwd)
def setup_deps():
    """Install the Space's heavy dependencies once, then re-exec this script.

    Returns immediately when ``HF_SPACE_BOOTSTRAPPED`` is ``"1"`` (this is the
    re-executed process) or when ``torch`` and ``sam2`` already import.
    Otherwise installs the CPU torch wheels, the local ``segment-anything-2``
    checkout in editable mode, and ``requirements_manual.txt`` (without
    dependency resolution), sets the flag, and replaces the current process
    with a fresh run of the same command line so the new packages are
    importable.

    Raises:
        subprocess.CalledProcessError: if any install command fails.
    """
    import shlex

    # Guard against an install/re-exec loop: the flag is set just before the
    # process re-executes itself at the end of this function.
    if os.environ.get("HF_SPACE_BOOTSTRAPPED") == "1":
        return

    try:
        import torch  # noqa: F401
        import sam2  # noqa: F401
    except ImportError:
        pass
    else:
        print("🔧 Dependencies already installed.")
        return  # all good, don't reinstall

    print("🔧 Installing dependencies...")
    # Run pip through the current interpreter so packages are installed into
    # the same environment that os.execv() restarts below, not into whichever
    # interpreter owns the `pip` found on PATH.
    pip = f"{shlex.quote(sys.executable)} -m pip"
    run(f"{pip} install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu")
    run(f"{pip} install -e .", cwd="segment-anything-2")
    run(f"{pip} install --no-deps -r requirements_manual.txt")

    # Relaunch the script with an env flag to avoid looping.
    print("♻️ Restarting app to apply changes...")
    os.environ["HF_SPACE_BOOTSTRAPPED"] = "1"
    # exec*() replaces the process without flushing Python-level buffers, so
    # flush explicitly or the messages above may be lost.
    sys.stdout.flush()
    sys.stderr.flush()
    os.execv(sys.executable, [sys.executable] + sys.argv)
# Bootstrap step: this must run before the third-party imports that follow.
# On a fresh environment setup_deps() installs torch/SAM2 and re-executes
# this script; otherwise it returns immediately.
setup_deps()
import gradio as gr
import numpy as np
from PIL import Image
import sam_utils
import matplotlib.pyplot as plt
from io import BytesIO
from sam2.sam2_image_predictor import SAM2ImagePredictor
# SAM2 inference helpers: segment the clicked object in the reference image, then propagate that mask to the target images.
def segment_reference(image, click):
    """Segment the reference image from the user's clicked point prompts.

    Every clicked point is passed to SAM2 as a foreground prompt (label 1),
    and a single mask is requested (``multimask_output=False``).

    Args:
        image: the reference image, passed to ``sam2_img.set_image``.
        click: list of ``[x, y]`` pixel coordinates, one per click.

    Returns:
        The ``masks`` array from ``sam2_img.predict``.
        NOTE(review): presumably shape (1, H, W) for a single prompt set; confirm.

    Raises:
        ValueError: if ``click`` contains no points.
    """
    print(f"Segmenting reference at point: {click}")
    point_coords = np.asarray(click)
    if point_coords.size == 0:
        raise ValueError("segment_reference needs at least one clicked point")
    # In SAM2's point-prompt API, label 1 marks a foreground (positive) point.
    point_labels = np.ones(len(point_coords), dtype=int)

    sam2_img.set_image(image)
    masks, _, _ = sam2_img.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=False,
    )
    return masks
def segment_target(target_images, ref_image, ref_mask):
    """Propagate the reference mask onto every target image with SAM2.

    The targets and the reference are converted to numpy arrays and handed
    to ``sam_utils.load_masks`` together with ``ref_mask``. The propagation
    result is then read back through ``sam_utils.propagate_masks``.

    Returns:
        One ``'segmentation'`` entry per propagated result, skipping the first
        result. NOTE(review): the skipped entry is presumably the reference
        frame itself; confirm against ``sam_utils``.
    """
    frames = list(map(np.array, target_images))
    reference = np.array(ref_image)
    inference_state = sam_utils.load_masks(sam2_vid, frames, reference, ref_mask)
    propagated = sam_utils.propagate_masks(sam2_vid, inference_state)
    targets_only = propagated[1:]
    return [entry['segmentation'] for entry in targets_only]
def on_reference_upload(img):
    """Forget all recorded clicks when the reference image changes.

    Args:
        img: the new reference image value (unused).

    Returns:
        Status text for the click-info textbox.
    """
    global click_coords
    click_coords = list()
    status = "Click Info: Cleared (new image uploaded)"
    return status
def visualize_segmentation(image, masks, target_images, target_masks):
    """Render the reference and target segmentations into a single image.

    The figure is a 2 x len(target_images) grid. The top-left cell shows the
    reference image with its masks, and the remaining top cells are blank.
    The bottom row shows each target image with its inferred masks. Images
    are drawn in grayscale (``convert("L")``, ``cmap='gray'``) beneath the
    mask overlays drawn by ``sam_utils.show_mask``.

    Args:
        image: reference PIL image.
        masks: iterable of reference masks.
        target_images: non-empty sequence of target PIL images.
        target_masks: per-target iterables of masks, indexed like
            ``target_images``.

    Returns:
        A PIL image decoded from the figure rendered as PNG.

    Raises:
        ValueError: if ``target_images`` is empty.
    """
    num_tgt = len(target_images)
    if num_tgt == 0:
        raise ValueError("visualize_segmentation needs at least one target image")

    # squeeze=False always returns a 2-D (2, num_tgt) array of Axes, so the
    # single-target case needs no special reshaping.
    fig, ax = plt.subplots(2, num_tgt, figsize=(6 * num_tgt, 12), squeeze=False)
    try:
        # Top row: reference image with the expert masks in the first cell.
        ref_ax = ax[0][0]
        ref_ax.imshow(image.convert("L"), cmap='gray')
        for i, mask in enumerate(masks):
            sam_utils.show_mask(mask, ref_ax, obj_id=i, alpha=0.75)
        ref_ax.axis('off')
        ref_ax.set_title("Reference Image with Expert Segmentation")
        for blank_ax in ax[0][1:]:
            blank_ax.axis('off')

        # Bottom row: each target image with its inferred masks.
        for i, target in enumerate(target_images):
            tgt_ax = ax[1][i]
            tgt_ax.imshow(target.convert("L"), cmap='gray')
            for j, mask in enumerate(target_masks[i]):
                sam_utils.show_mask(mask, tgt_ax, obj_id=j, alpha=0.75)
            tgt_ax.axis('off')
            tgt_ax.set_title("Target Image with Inferred Segmentation")

        # Lay out and save through this Figure object rather than pyplot's
        # process-global "current figure".
        fig.tight_layout()
        with BytesIO() as buf:
            fig.savefig(buf, format='png')
            buf.seek(0)
            # copy() loads the pixel data before the buffer is closed.
            vis = Image.open(buf).copy()
    finally:
        # Release the figure even on error so a long-running server does not
        # accumulate open pyplot figures.
        plt.close(fig)
    return vis
# Point prompts clicked on the reference image, stored as [x, y] pairs.
# record_click appends to this list; on_reference_upload and a successful
# generate() reset it to empty.
# NOTE(review): this is module-global, so every browser session served by
# this process shares it and concurrent users' clicks can interleave.
# Per-session storage (gr.State) would avoid that.
click_coords = []
def record_click(img, evt: gr.SelectData):
    """Record one click on the reference image as a point prompt.

    Args:
        img: the reference image value (unused).
        evt: Gradio select event. The first two items of ``evt.index`` are
            stored as the point (presumably the (x, y) pixel position for an
            Image component).

    Returns:
        Status text listing every point recorded so far.
    """
    global click_coords
    x, y = evt.index[0], evt.index[1]
    click_coords.append([x, y])
    return f"Clicked at: {click_coords}"
def _open_target_image(upload):
    """Load one uploaded target as a 1024x1024 RGB image and close its file."""
    # Depending on the Gradio version, gr.File may yield file objects with a
    # ``.name`` path or plain path strings; accept either.
    path = getattr(upload, "name", upload)
    with Image.open(path) as img:
        # convert() loads the pixels into a new image, so the source file can
        # be closed as soon as this block exits.
        return img.convert("RGB").resize((1024, 1024))


def generate(reference_image, target_images):
    """Handle the Generate button.

    Segments the reference image from the recorded clicks, propagates the
    mask to the uploaded targets, and renders the result.

    Args:
        reference_image: reference PIL image from the image component.
        target_images: uploaded target files, or ``None``/empty when nothing
            was uploaded.

    Returns:
        ``(visualization, status)`` for ``(output_mask, click_info)``. The
        visualization is ``None`` when the inputs are incomplete.
    """
    global click_coords
    if not click_coords:
        return None, "Click on the reference image first!"
    if reference_image is None:
        return None, "Upload a reference image first!"
    # An empty File component gives None (or an empty list); iterating None
    # would raise TypeError.
    if not target_images:
        return None, "Upload at least one target image first!"

    targets = [_open_target_image(f) for f in target_images]
    ref_mask = segment_reference(reference_image, click_coords)
    tgt_masks = segment_target(targets, reference_image, ref_mask)
    vis = visualize_segmentation(reference_image, ref_mask, targets, tgt_masks)
    # Clear the clicks so the next run starts from a fresh prompt.
    click_coords = []
    return vis, "Done!"
# UI definition. Gradio renders components in the order they are created, so
# the statement order below is the on-page layout order.
with gr.Blocks() as demo:
    gr.Markdown("### SST Demo: Label-Efficient Trait Segmentation")
    with gr.Row():
        # Reference image: clicks become point prompts (record_click); any
        # change of the image clears them (on_reference_upload).
        reference_img = gr.Image(type="pil", label="Reference Image")
        # One or more target images that receive the propagated segmentation.
        target_img = gr.File(file_types=["image"], file_count="multiple", label="Target Images")
    # Status line; every handler wired below writes to it.
    click_info = gr.Textbox(label="Click Info")
    generate_btn = gr.Button("Generate")
    output_mask = gr.Image(type="pil", label="Generated Mask")
    # Event wiring, kept inside the Blocks context.
    reference_img.select(fn=record_click, inputs=[reference_img], outputs=[click_info])
    reference_img.change(fn=on_reference_upload, inputs=[reference_img], outputs=[click_info])
    generate_btn.click(fn=generate, inputs=[reference_img, target_img], outputs=[output_mask, click_info])
# SAM2 checkpoint and config shared by the image and video predictors.
SAM2_CKPT = "checkpoints/sam2_hiera_small.pt"
SAM2_CFG = "checkpoints/sam2_hiera_s.yaml"

# Module-level predictors read by segment_reference (sam2_img) and
# segment_target (sam2_vid). They are built before demo.launch(), so both
# exist before any event handler can run. (`global` is a no-op at module
# level, so no declaration is needed.)
sam2_img = SAM2ImagePredictor(
    sam_utils.load_SAM2(ckpt_path=SAM2_CKPT, model_cfg_path=SAM2_CFG)
)
sam2_vid = sam_utils.build_sam2_predictor(checkpoint=SAM2_CKPT, model_cfg=SAM2_CFG)

demo.launch()