|
import subprocess |
|
|
|
import os |
|
import sys |
|
import subprocess |
|
|
|
def run(cmd, cwd=None):
    """Echo a command, run it, and raise if it exits with a non-zero status.

    Args:
        cmd: The command to execute. A string is handed to the shell, exactly
            as before. A list or tuple of arguments is executed directly,
            without a shell, so its items need no quoting.
        cwd: Optional working directory for the child process.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status.
    """
    # Flush the echo so it is written before the child's own output even when
    # stdout is block-buffered (e.g. redirected to a log file or pipe).
    print(f"▶ {cmd}", flush=True)
    subprocess.check_call(cmd, shell=isinstance(cmd, str), cwd=cwd)
|
|
|
def setup_deps():
    """Install the runtime dependencies on first start, then restart the app.

    Does nothing when the ``HF_SPACE_BOOTSTRAPPED`` environment variable is
    ``"1"`` (set by this function right before it re-executes the script) or
    when both ``torch`` and ``sam2`` can already be imported. Otherwise it
    installs the packages with pip and replaces the current process with a
    fresh run of the same script, so this function does not return in that
    case.

    Raises:
        subprocess.CalledProcessError: If one of the pip commands fails.
    """
    import shlex

    if os.environ.get("HF_SPACE_BOOTSTRAPPED") == "1":
        return

    try:
        # Imported only to probe availability; the names are not used here.
        import torch  # noqa: F401
        import sam2  # noqa: F401
    except ImportError:
        pass
    else:
        print("🔧 Dependencies already installed.")
        return

    print("🔧 Installing dependencies...")
    # Use the pip that belongs to the running interpreter: a bare `pip` found
    # on PATH may install into a different Python than the one re-executed
    # below, leaving the packages invisible after the restart.
    pip = f"{shlex.quote(sys.executable)} -m pip"
    run(f"{pip} install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu")
    run(f"{pip} install -e .", cwd="segment-anything-2")
    run(f"{pip} install --no-deps -r requirements_manual.txt")

    print("♻️ Restarting app to apply changes...")
    os.environ["HF_SPACE_BOOTSTRAPPED"] = "1"
    # execv replaces the process immediately without flushing Python's
    # buffered streams, so flush them first or the messages above can be lost.
    sys.stdout.flush()
    sys.stderr.flush()
    os.execv(sys.executable, [sys.executable] + sys.argv)
|
|
|
# Bootstrap before the imports below: setup_deps() installs torch and sam2
# when they cannot be imported and then re-executes this script.
setup_deps()
|
|
|
import gradio as gr |
|
import numpy as np |
|
from PIL import Image |
|
import sam_utils |
|
import matplotlib.pyplot as plt |
|
from io import BytesIO |
|
|
|
from sam2.sam2_image_predictor import SAM2ImagePredictor |
|
|
|
|
|
|
|
def segment_reference(image, click):
    """Segment the reference image from the clicked points.

    Args:
        image: The reference image; it is passed unchanged to
            ``sam2_img.set_image``.
        click: Sequence of ``[a, b]`` click positions as collected by
            ``record_click`` (presumably x, y pixel coordinates).

    Returns:
        The first element of the tuple returned by ``sam2_img.predict``
        (the predicted masks).
    """
    print(f"Segmenting reference at point: {click}")
    point_coords = np.array(click)
    # Every click gets label 1 (NOTE(review): in SAM's prompt convention this
    # marks a positive / foreground point — confirm against the predictor).
    point_labels = np.ones(len(point_coords), dtype=int)

    sam2_img.set_image(image)
    masks, _, _ = sam2_img.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=False,
    )
    return masks
|
|
|
def segment_target(target_images, ref_image, ref_mask):
    """Propagate the reference mask onto every target image.

    Args:
        target_images: Iterable of target images; each is converted with
            ``np.array``.
        ref_image: The reference image; also converted with ``np.array``.
        ref_mask: Mask(s) of the reference image, forwarded to
            ``sam_utils.load_masks``.

    Returns:
        A list holding the ``'segmentation'`` value of every propagated
        entry except the first one.
    """
    frames = []
    for picture in target_images:
        frames.append(np.array(picture))
    reference = np.array(ref_image)

    state = sam_utils.load_masks(sam2_vid, frames, reference, ref_mask)
    propagated = sam_utils.propagate_masks(sam2_vid, state)

    # The first entry is skipped (presumably it belongs to the reference
    # frame itself — verify against sam_utils.propagate_masks).
    segmentations = []
    for entry in propagated[1:]:
        segmentations.append(entry['segmentation'])
    return segmentations
|
|
|
def on_reference_upload(img):
    """Forget all recorded clicks when the reference image changes.

    Args:
        img: The new reference image value (unused).

    Returns:
        The status text to show in the click-info box.
    """
    global click_coords
    status = "Click Info: Cleared (new image uploaded)"
    # Rebind to a fresh list; clicks made on the previous image are dropped.
    click_coords = list()
    return status
|
|
|
def visualize_segmentation(image, masks, target_images, target_masks):
    """Render the reference and target segmentations into a single image.

    The figure has two rows: the top-left cell shows the reference image
    with ``masks`` overlaid, and the bottom row shows one target image per
    column with its masks overlaid. The remaining top-row cells are left
    empty.

    Args:
        image: Reference image; must provide ``convert("L")``.
        masks: Iterable of masks drawn over the reference image.
        target_images: Non-empty sequence of target images, each providing
            ``convert("L")``.
        target_masks: Sequence indexed like ``target_images``; each item is
            an iterable of masks for that target.

    Returns:
        A PIL image containing the rendered PNG figure.

    Raises:
        ValueError: If ``target_images`` is empty.
    """
    num_tgt = len(target_images)
    if num_tgt == 0:
        raise ValueError("visualize_segmentation needs at least one target image")

    # squeeze=False always yields a 2-D array of axes, so the single-target
    # case needs no special handling.
    fig, axes = plt.subplots(2, num_tgt, figsize=(6 * num_tgt, 12), squeeze=False)
    try:
        ref_ax = axes[0][0]
        ref_ax.imshow(image.convert("L"), cmap='gray')
        for obj_id, mask in enumerate(masks):
            sam_utils.show_mask(mask, ref_ax, obj_id=obj_id, alpha=0.75)
        ref_ax.set_title("Reference Image with Expert Segmentation")
        # Hide the axes of the whole top row, including the unused cells.
        for top_ax in axes[0]:
            top_ax.axis('off')

        for col, tgt_image in enumerate(target_images):
            tgt_ax = axes[1][col]
            tgt_ax.imshow(tgt_image.convert("L"), cmap='gray')
            for obj_id, mask in enumerate(target_masks[col]):
                sam_utils.show_mask(mask, tgt_ax, obj_id=obj_id, alpha=0.75)
            tgt_ax.axis('off')
            tgt_ax.set_title("Target Image with Inferred Segmentation")

        # Use the figure's own methods rather than pyplot's implicit
        # "current figure", so no other figure can be picked up by mistake.
        fig.tight_layout()
        with BytesIO() as buf:
            fig.savefig(buf, format='png')
            buf.seek(0)
            # copy() forces the pixel data to load before the buffer closes.
            return Image.open(buf).copy()
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)
|
|
|
|
|
# Clicks recorded on the reference image, stored as two-element lists
# [evt.index[0], evt.index[1]]. Module-level state shared by record_click,
# on_reference_upload and generate.
click_coords = []
|
|
|
def record_click(img, evt: gr.SelectData):
    """Remember a click on the reference image.

    Args:
        img: Current reference image value (unused).
        evt: Selection event; the first two items of ``evt.index`` are
            stored as the click position.

    Returns:
        A status text listing every click recorded so far.
    """
    global click_coords
    first, second = evt.index[0], evt.index[1]
    click_coords.append([first, second])
    return "Clicked at: {}".format(click_coords)
|
|
|
def generate(reference_image, target_images):
    """Segment the reference image and transfer the result to the targets.

    Args:
        reference_image: The reference image, or ``None`` when none has been
            uploaded.
        target_images: Uploaded target files. Each item is either an object
            whose ``name`` attribute is the file path or the path itself.
            May be ``None`` or empty when nothing has been uploaded.

    Returns:
        A ``(visualization, status)`` tuple. ``visualization`` is ``None``
        when a required input is missing; ``status`` then says what to do.
    """
    global click_coords

    # Guard against missing inputs so the handler reports the problem
    # instead of raising.
    if reference_image is None:
        return None, "Upload a reference image first!"
    if not click_coords:
        return None, "Click on the reference image first!"
    if not target_images:
        return None, "Upload at least one target image first!"

    resized_targets = []
    for upload in target_images:
        # Accept both file wrappers exposing `.name` and plain path strings.
        path = getattr(upload, "name", upload)
        with Image.open(path) as opened:
            resized_targets.append(opened.convert("RGB").resize((1024, 1024)))

    ref_mask = segment_reference(reference_image, click_coords)
    tgt_masks = segment_target(resized_targets, reference_image, ref_mask)
    vis = visualize_segmentation(reference_image, ref_mask, resized_targets, tgt_masks)

    # Start from a clean slate for the next run.
    click_coords = []
    return vis, "Done!"
|
|
|
# Build the UI: a reference image that records clicks, a multi-file upload
# for the targets, a status box, a trigger button and the result image.
with gr.Blocks() as demo:
    gr.Markdown("### SST Demo: Label-Efficient Trait Segmentation")

    with gr.Row():
        reference_img = gr.Image(type="pil", label="Reference Image")
        target_img = gr.File(
            file_types=["image"],
            file_count="multiple",
            label="Target Images",
        )

    click_info = gr.Textbox(label="Click Info")
    generate_btn = gr.Button("Generate")
    output_mask = gr.Image(type="pil", label="Generated Mask")

    # Event wiring: clicks are recorded, a changed image clears them, and the
    # button runs the full pipeline.
    reference_img.select(
        fn=record_click,
        inputs=[reference_img],
        outputs=[click_info],
    )
    reference_img.change(
        fn=on_reference_upload,
        inputs=[reference_img],
        outputs=[click_info],
    )
    generate_btn.click(
        fn=generate,
        inputs=[reference_img, target_img],
        outputs=[output_mask, click_info],
    )

# Load the models used by segment_reference (sam2_img) and segment_target
# (sam2_vid) before the server starts.
sam2_img = SAM2ImagePredictor(
    sam_utils.load_SAM2(
        ckpt_path="checkpoints/sam2_hiera_small.pt",
        model_cfg_path="checkpoints/sam2_hiera_s.yaml",
    )
)
sam2_vid = sam_utils.build_sam2_predictor(
    checkpoint="checkpoints/sam2_hiera_small.pt",
    model_cfg="checkpoints/sam2_hiera_s.yaml",
)

demo.launch()
|
|