import os
import subprocess
import sys


def run(cmd, cwd=None):
    """Run a shell command, echoing it first."""
    print(f"▶ {cmd}")
    subprocess.check_call(cmd, shell=True, cwd=cwd)


def setup_deps():
    # Use a flag to prevent infinite restarts.
    if os.environ.get("HF_SPACE_BOOTSTRAPPED") == "1":
        return

    # Try importing the heavy dependencies to check if they're already set up.
    try:
        import torch  # noqa: F401
        import sam2  # noqa: F401

        print("🔧 Dependencies already installed.")
        return  # all good, don't reinstall
    except ImportError:
        pass

    print("🔧 Installing dependencies...")
    run("pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu")
    run("pip install -e .", cwd="segment-anything-2")
    run("pip install --no-deps -r requirements_manual.txt")

    # Relaunch the script; the env flag set here survives os.execv, so the
    # new process skips this bootstrap instead of looping.
    print("♻️ Restarting app to apply changes...")
    os.environ["HF_SPACE_BOOTSTRAPPED"] = "1"
    os.execv(sys.executable, [sys.executable] + sys.argv)


setup_deps()

from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor

import sam_utils


# SAM2 inference helpers
def segment_reference(image, clicks):
    """Segment the reference image from positive point prompts.

    clicks: list of [x, y] points. Returns masks of shape [1, H, W]
    (multimask_output=False keeps a single mask per prompt set).
    """
    print(f"Segmenting reference at points: {clicks}")
    point_coords = np.array(clicks)
    point_labels = np.ones(len(clicks), dtype=np.int32)  # all positive clicks
    sam2_img.set_image(image)
    masks, _, _ = sam2_img.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=False,
    )
    return masks


def segment_target(target_images, ref_image, ref_mask):
    """Propagate the reference mask onto the targets via the video predictor."""
    target_images = [np.array(target_image) for target_image in target_images]
    ref_image = np.array(ref_image)
    state = sam_utils.load_masks(sam2_vid, target_images, ref_image, ref_mask)
    out = sam_utils.propagate_masks(sam2_vid, state)[1:]  # drop the reference frame
    return [mask["segmentation"] for mask in out]


def on_reference_upload(img):
    global click_coords
    click_coords = []  # clear the clicks for the new image
    return "Click Info: Cleared (new image uploaded)"


def visualize_segmentation(image, masks, target_images, target_masks):
    """Render reference and target segmentations in a grid; return a PIL image."""
    num_tgt = len(target_images)
    fig, ax = plt.subplots(2, num_tgt, figsize=(6 * num_tgt, 12))
    if num_tgt == 1:
        ax = np.expand_dims(ax, axis=1)  # keep 2D indexing for a single column

    # Top row: reference image with the expert (clicked) segmentation.
    ax[0][0].imshow(image.convert("L"), cmap="gray")
    for i, mask in enumerate(masks):
        sam_utils.show_mask(mask, ax[0][0], obj_id=i, alpha=0.75)
    ax[0][0].axis("off")
    ax[0][0].set_title("Reference Image with Expert Segmentation")
    for i in range(1, num_tgt):  # blank out the unused top-row axes
        ax[0][i].axis("off")

    # Bottom row: target images with the propagated segmentations.
    for i in range(num_tgt):
        ax[1][i].imshow(target_images[i].convert("L"), cmap="gray")
        for j, mask in enumerate(target_masks[i]):
            sam_utils.show_mask(mask, ax[1][i], obj_id=j, alpha=0.75)
        ax[1][i].axis("off")
        ax[1][i].set_title("Target Image with Inferred Segmentation")

    # Save the figure to an in-memory buffer and return it as a PIL image.
    plt.tight_layout()
    buf = BytesIO()
    plt.savefig(buf, format="png")
    buf.seek(0)
    vis = Image.open(buf).copy()
    plt.close(fig)
    buf.close()
    return vis


# Store click coords globally (could be replaced with gr.State).
click_coords = []


def record_click(img, evt: gr.SelectData):
    global click_coords
    click_coords.append([evt.index[0], evt.index[1]])
    return f"Clicked at: {click_coords}"


def generate(reference_image, target_images):
    global click_coords
    if not click_coords:
        return None, "Click on the reference image first!"
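    # NOTE: targets are resized to 1024x1024 below, presumably to match
    # SAM2's 1024x1024 input resolution; the reference image is used at
    # its uploaded size. Adjust if your checkpoint expects otherwise.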
    target_images = [Image.open(f.name).convert("RGB").resize((1024, 1024)) for f in target_images]
    ref_mask = segment_reference(reference_image, click_coords)
    tgt_masks = segment_target(target_images, reference_image, ref_mask)
    vis = visualize_segmentation(reference_image, ref_mask, target_images, tgt_masks)
    click_coords = []  # clear the clicks for the next run
    return vis, "Done!"


with gr.Blocks() as demo:
    gr.Markdown("### SST Demo: Label-Efficient Trait Segmentation")
    with gr.Row():
        reference_img = gr.Image(type="pil", label="Reference Image")
        target_img = gr.File(file_types=["image"], file_count="multiple", label="Target Images")
    click_info = gr.Textbox(label="Click Info")
    generate_btn = gr.Button("Generate")
    output_mask = gr.Image(type="pil", label="Generated Mask")

    reference_img.select(fn=record_click, inputs=[reference_img], outputs=[click_info])
    reference_img.change(fn=on_reference_upload, inputs=[reference_img], outputs=[click_info])
    generate_btn.click(fn=generate, inputs=[reference_img, target_img], outputs=[output_mask, click_info])

# Build the predictors at module scope (a `global` statement is a no-op
# here); the callbacks only run after demo.launch(), so these assignments
# are in place before first use.
sam2_img = sam_utils.load_SAM2(
    ckpt_path="checkpoints/sam2_hiera_small.pt",
    model_cfg_path="checkpoints/sam2_hiera_s.yaml",
)
sam2_img = SAM2ImagePredictor(sam2_img)
sam2_vid = sam_utils.build_sam2_predictor(
    checkpoint="checkpoints/sam2_hiera_small.pt",
    model_cfg="checkpoints/sam2_hiera_s.yaml",
)

demo.launch()
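# Usage sketch (assuming this file is the Space's entry point, with the
# segment-anything-2 repo and checkpoints/ present alongside it):
#
#   python app.py   # first run installs deps, then re-execs itself;
#                   # HF_SPACE_BOOTSTRAPPED, set before os.execv, is
#                   # inherited by the relaunched process, so setup_deps()
#                   # returns immediately the second time around.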