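# SST demo Space: click on a reference image to segment a trait with SAM2,
# then propagate that mask to a batch of target images.
#
# Assumed repo layout, inferred from the paths used below:
#   segment-anything-2/        local SAM2 clone, installed editable on first boot
#   requirements_manual.txt    extra pinned deps installed with --no-deps
#   checkpoints/               sam2_hiera_small.pt and sam2_hiera_s.yaml
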
import subprocess
import os
import sys

def run(cmd, cwd=None):
    """Echo a shell command, then run it; raises CalledProcessError on failure."""
    print(f"▶ {cmd}")
    subprocess.check_call(cmd, shell=True, cwd=cwd)

def setup_deps():
    """On first boot, install Torch and SAM2, then re-exec this script once."""
    # Use a flag to prevent infinite restarts.
    if os.environ.get("HF_SPACE_BOOTSTRAPPED") == "1":
        return
    # If torch and sam2 already import cleanly, the environment is set up.
    try:
        import torch  # noqa: F401
        import sam2  # noqa: F401
        print("🔧 Dependencies already installed.")
        return  # all good, don't reinstall
    except ImportError:
        pass
    print("🔧 Installing dependencies...")
    run("pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu")
    run("pip install -e .", cwd="segment-anything-2")
    run("pip install --no-deps -r requirements_manual.txt")
    # Relaunch the script with the env flag set so the new process skips setup.
    print("♻️ Restarting app to apply changes...")
    os.environ["HF_SPACE_BOOTSTRAPPED"] = "1"
    os.execv(sys.executable, [sys.executable] + sys.argv)

setup_deps()
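
# Heavy imports are deferred until after setup_deps() so they are only
# resolved once the dependencies above are guaranteed to be installed.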
import gradio as gr
import numpy as np
from PIL import Image
import sam_utils
import matplotlib.pyplot as plt
from io import BytesIO
from sam2.sam2_image_predictor import SAM2ImagePredictor

def segment_reference(image, click):
    """Run the SAM2 image predictor on the reference image.

    `click` is a list of [x, y] points; every point is treated as a positive
    (foreground) prompt. Returns the predicted mask(s) as a 0/1 numpy array.
    """
    print(f"Segmenting reference at point: {click}")
    click = np.array(click)
    input_label = np.ones(len(click), dtype=np.int32)  # 1 = foreground point
    sam2_img.set_image(image)
    masks, _, _ = sam2_img.predict(
        point_coords=click,
        point_labels=input_label,
        multimask_output=False,
    )
    return masks

def segment_target(target_images, ref_image, ref_mask):
    """Propagate the reference mask onto each target image via the video predictor."""
    target_images = [np.array(target_image) for target_image in target_images]
    ref_image = np.array(ref_image)
    state = sam_utils.load_masks(sam2_vid, target_images, ref_image, ref_mask)
    out = sam_utils.propagate_masks(sam2_vid, state)[1:]  # skip the reference frame itself
    return [mask['segmentation'] for mask in out]

def on_reference_upload(img):
    """Reset stored clicks whenever a new reference image is uploaded."""
    global click_coords
    click_coords = []  # clear the clicks
    return "Click Info: Cleared (new image uploaded)"

def visualize_segmentation(image, masks, target_images, target_masks):
    """Plot reference + target segmentations in a 2-row grid; return a PIL image."""
    num_tgt = len(target_images)
    fig, ax = plt.subplots(2, num_tgt, figsize=(6 * num_tgt, 12))
    if num_tgt == 1:
        ax = np.expand_dims(ax, axis=1)  # keep 2-D [row][col] indexing
    # Top-left panel: reference image overlaid with the expert mask(s).
    ax[0][0].imshow(image.convert("L"), cmap='gray')
    for i, mask in enumerate(masks):
        sam_utils.show_mask(mask, ax[0][0], obj_id=i, alpha=0.75)
    ax[0][0].axis('off')
    ax[0][0].set_title("Reference Image with Expert Segmentation")
    for i in range(1, num_tgt):
        # leave the rest of the top row empty
        ax[0][i].axis('off')
    # Bottom row: each target image overlaid with its propagated mask(s).
    for i in range(num_tgt):
        ax[1][i].imshow(target_images[i].convert("L"), cmap='gray')
        for j, mask in enumerate(target_masks[i]):
            sam_utils.show_mask(mask, ax[1][i], obj_id=j, alpha=0.75)
        ax[1][i].axis('off')
        ax[1][i].set_title("Target Image with Inferred Segmentation")
    # Render the figure to an in-memory PNG and return a detached PIL copy.
    plt.tight_layout()
    buf = BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    vis = Image.open(buf).copy()
    plt.close(fig)
    buf.close()
    return vis

# Store click coordinates globally (a gr.State would make this per-session).
click_coords = []

def record_click(img, evt: gr.SelectData):
    """Record the pixel coordinate of each click on the reference image."""
    global click_coords
    click_coords.append([evt.index[0], evt.index[1]])
    return f"Clicked at: {click_coords}"

def generate(reference_image, target_images):
    """Segment the reference from the stored clicks, then propagate to the targets."""
    global click_coords
    if not click_coords:
        return None, "Click on the reference image first!"
    # gr.File yields temp files; load and normalize each target image.
    target_images = [Image.open(f.name).convert("RGB").resize((1024, 1024)) for f in target_images]
    ref_mask = segment_reference(reference_image, click_coords)
    tgt_masks = segment_target(target_images, reference_image, ref_mask)
    vis = visualize_segmentation(reference_image, ref_mask, target_images, tgt_masks)
    click_coords = []  # clear the clicks for the next run
    return vis, "Done!"
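
# Gradio UI: a clickable reference image, a multi-file target upload, and a
# button that runs the full segment-and-propagate pipeline.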
with gr.Blocks() as demo:
    gr.Markdown("### SST Demo: Label-Efficient Trait Segmentation")
    with gr.Row():
        reference_img = gr.Image(type="pil", label="Reference Image")
        target_img = gr.File(file_types=["image"], file_count="multiple", label="Target Images")
    click_info = gr.Textbox(label="Click Info")
    generate_btn = gr.Button("Generate")
    output_mask = gr.Image(type="pil", label="Generated Mask")

    reference_img.select(fn=record_click, inputs=[reference_img], outputs=[click_info])
    reference_img.change(fn=on_reference_upload, inputs=[reference_img], outputs=[click_info])
    generate_btn.click(fn=generate, inputs=[reference_img, target_img], outputs=[output_mask, click_info])

# Load SAM2 once at startup: an image predictor for the reference clicks and a
# video predictor for propagating the mask to the target images.
sam2_img = sam_utils.load_SAM2(ckpt_path="checkpoints/sam2_hiera_small.pt", model_cfg_path="checkpoints/sam2_hiera_s.yaml")
sam2_img = SAM2ImagePredictor(sam2_img)
sam2_vid = sam_utils.build_sam2_predictor(checkpoint="checkpoints/sam2_hiera_small.pt", model_cfg="checkpoints/sam2_hiera_s.yaml")

demo.launch()