| import subprocess |
|
|
| import os |
| import sys |
| import subprocess |
|
|
def run(cmd, cwd=None):
    """Echo *cmd* and execute it in a shell, from *cwd* if given.

    Raises subprocess.CalledProcessError when the command exits non-zero.

    NOTE(review): shell=True is tolerable only because every caller passes a
    hard-coded command string — never feed untrusted input through this helper.
    """
    print("▶ " + cmd)
    subprocess.check_call(cmd, cwd=cwd, shell=True)
|
|
def setup_deps():
    """One-time bootstrap: install heavy dependencies, then re-exec the process.

    Intended for a Hugging Face Space where torch/sam2 may be missing on first
    boot. The HF_SPACE_BOOTSTRAPPED env var guards against an infinite
    install/restart loop after the os.execv below.
    """
    # Already restarted once after installing — nothing to do.
    if os.environ.get("HF_SPACE_BOOTSTRAPPED") == "1":
        return

    # If both heavy packages import cleanly, skip installation entirely.
    try:
        import torch
        import sam2
        print("🔧 Dependencies already installed.")
        return
    except ImportError:
        pass

    print("🔧 Installing dependencies...")
    # CPU-only torch wheels; sam2 installed editable from a sibling checkout.
    run("pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu")
    run("pip install -e .", cwd="segment-anything-2")
    run("pip install --no-deps -r requirements_manual.txt")

    # Replace the current process so the freshly installed packages are
    # importable; the env var (inherited through execv) prevents a re-install.
    print("♻️ Restarting app to apply changes...")
    os.environ["HF_SPACE_BOOTSTRAPPED"] = "1"
    os.execv(sys.executable, [sys.executable] + sys.argv)
|
|
| setup_deps() |
|
|
| import gradio as gr |
| import numpy as np |
| from PIL import Image |
| import sam_utils |
| import matplotlib.pyplot as plt |
| from io import BytesIO |
|
|
| from sam2.sam2_image_predictor import SAM2ImagePredictor |
|
|
|
|
| |
def segment_reference(image, click):
    """Segment the reference image with SAM2 using the user's click prompt(s).

    Args:
        image: PIL image the user clicked on.
        click: sequence of [x, y] pixel coordinates, one per recorded click.

    Returns:
        Mask array from SAM2ImagePredictor.predict (multimask_output=False,
        so a single mask per prompt set).
    """
    print(f"Segmenting reference at point: {click}")
    click = np.asarray(click)
    # Every recorded click is a positive (foreground) prompt.
    # (np.ones replaces the original list-comprehension; the unused
    # `width, height = image.size` locals were removed.)
    input_label = np.ones(len(click), dtype=int)

    sam2_img.set_image(image)
    masks, _, _ = sam2_img.predict(
        point_coords=click,
        point_labels=input_label,
        multimask_output=False,
    )
    return masks
|
|
def segment_target(target_images, ref_image, ref_mask):
    """Propagate the reference mask onto each target image with the SAM2 video predictor.

    Returns one segmentation array per target image.
    """
    frames = [np.array(frame) for frame in target_images]
    anchor = np.array(ref_image)
    state = sam_utils.load_masks(sam2_vid, frames, anchor, ref_mask)
    # Skip index 0: that entry corresponds to the reference frame itself.
    propagated = sam_utils.propagate_masks(sam2_vid, state)[1:]
    return [entry['segmentation'] for entry in propagated]
|
|
def on_reference_upload(img):
    """Clear any recorded click prompts whenever a new reference image is uploaded."""
    global click_coords
    click_coords = []
    status = "Click Info: Cleared (new image uploaded)"
    return status
|
|
def visualize_segmentation(image, masks, target_images, target_masks):
    """Render the reference and target segmentations into one PIL image.

    Layout: top row shows the reference image with its masks overlaid
    (remaining top-row slots are blanked); bottom row shows each target
    image with its inferred masks. Images are displayed in grayscale so the
    colored mask overlays stand out.
    """
    num_tgt = len(target_images)
    fig, ax = plt.subplots(2, num_tgt, figsize=(6*num_tgt, 12))
    if num_tgt == 1:
        # With a single column, plt.subplots returns a 1-D axes array; add a
        # column axis so the ax[row][col] indexing below works uniformly.
        ax = np.expand_dims(ax, axis=1)
    ax[0][0].imshow(image.convert("L"), cmap='gray')
    for i, mask in enumerate(masks):
        sam_utils.show_mask(mask, ax[0][0], obj_id=i, alpha=0.75)
    ax[0][0].axis('off')
    ax[0][0].set_title("Reference Image with Expert Segmentation")
    for i in range(1, num_tgt):
        # The top row only holds the reference image; hide the unused slots.
        ax[0][i].axis('off')
    for i in range(num_tgt):
        ax[1][i].imshow(target_images[i].convert("L"), cmap='gray')
        for j, mask in enumerate(target_masks[i]):
            sam_utils.show_mask(mask, ax[1][i], obj_id=j, alpha=0.75)
        ax[1][i].axis('off')
        ax[1][i].set_title("Target Image with Inferred Segmentation")

    plt.tight_layout()
    # Round-trip through an in-memory PNG so the figure can be returned as a
    # standalone PIL image; .copy() detaches it from the buffer before close.
    buf = BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    vis = Image.open(buf).copy()
    plt.close(fig)
    buf.close()
    return vis
|
|
| |
| click_coords = [] |
|
|
def record_click(img, evt: gr.SelectData):
    """Append the (x, y) pixel location of a click on the reference image."""
    global click_coords
    x, y = evt.index[0], evt.index[1]
    click_coords.append([x, y])
    return f"Clicked at: {click_coords}"
|
|
def generate(reference_image, target_images):
    """Full pipeline: segment reference clicks, propagate to targets, render.

    Args:
        reference_image: PIL reference image (or None if not uploaded).
        target_images: list of uploaded file objects (each with a .name path).

    Returns:
        (visualization PIL image or None, status message string).
    """
    global click_coords
    if not click_coords:
        return None, "Click on the reference image first!"
    # Guard against missing uploads; previously Image.open would crash here.
    if reference_image is None or not target_images:
        return None, "Upload a reference image and at least one target image first!"

    # Resize targets to SAM2's expected 1024x1024 input resolution.
    target_images = [Image.open(f.name).convert("RGB").resize((1024, 1024)) for f in target_images]

    ref_mask = segment_reference(reference_image, click_coords)
    tgt_masks = segment_target(target_images, reference_image, ref_mask)
    vis = visualize_segmentation(reference_image, ref_mask, target_images, tgt_masks)

    # Reset recorded clicks so the next run starts fresh.
    click_coords = []
    return vis, "Done!"
|
|
# --- Gradio UI wiring ---
with gr.Blocks() as demo:
    gr.Markdown("### SST Demo: Label-Efficient Trait Segmentation")

    with gr.Row():
        reference_img = gr.Image(type="pil", label="Reference Image")
        target_img = gr.File(file_types=["image"], file_count="multiple", label="Target Images")

    click_info = gr.Textbox(label="Click Info")
    generate_btn = gr.Button("Generate")
    output_mask = gr.Image(type="pil", label="Generated Mask")

    # Clicking the reference image records prompt points; uploading a new
    # reference clears them; Generate runs the full segmentation pipeline.
    reference_img.select(fn=record_click, inputs=[reference_img], outputs=[click_info])
    reference_img.change(fn=on_reference_upload, inputs=[reference_img], outputs=[click_info])
    generate_btn.click(fn=generate, inputs=[reference_img, target_img], outputs=[output_mask, click_info])
|
|
# Load the SAM2 models once at startup; they are shared by every request via
# module globals. (The original `global sam2_img` / `global sam2_vid`
# statements were removed: `global` is a no-op at module scope.)
sam2_img = sam_utils.load_SAM2(ckpt_path="checkpoints/sam2_hiera_small.pt", model_cfg_path="checkpoints/sam2_hiera_s.yaml")
sam2_img = SAM2ImagePredictor(sam2_img)
sam2_vid = sam_utils.build_sam2_predictor(checkpoint="checkpoints/sam2_hiera_small.pt", model_cfg="checkpoints/sam2_hiera_s.yaml")
demo.launch()
|
|