Daniel-F committed on
Commit
7de04d2
·
1 Parent(s): 4c6b11a

support for multi-image

Browse files
Files changed (1) hide show
  1. app.py +33 -20
app.py CHANGED
@@ -63,31 +63,38 @@ def segment_reference(image, click):
63
 
64
  return masks
65
 
66
- def segment_target(target_image, ref_image, ref_mask):
67
- target_image = np.array(target_image)
68
  ref_image = np.array(ref_image)
69
- state = sam_utils.load_masks(sam2_vid, [target_image], ref_image, ref_mask)
70
- out = sam_utils.propagate_masks(sam2_vid, state)[-1]['segmentation']
71
- return out # Just for placeholder demo
72
 
73
  def on_reference_upload(img):
74
  global click_coords
75
  click_coords = [] # clear the clicks
76
  return "Click Info: Cleared (new image uploaded)"
77
 
78
- def visualize_segmentation(image, masks, target_image, target_mask):
79
  # Visualize the segmentation result
80
- fig, ax = plt.subplots(1, 2, figsize=(12, 6))
81
- ax[0].imshow(image.convert("L"), cmap='gray')
 
 
 
82
  for i, mask in enumerate(masks):
83
- sam_utils.show_mask(mask, ax[0], obj_id=i, alpha=0.75)
84
- ax[0].axis('off')
85
- ax[0].set_title("Reference Image with Expert Segmentation")
86
- ax[1].imshow(target_image.convert("L"), cmap='gray')
87
- for i, mask in enumerate(target_mask):
88
- sam_utils.show_mask(mask, ax[1], obj_id=i, alpha=0.75)
89
- ax[1].axis('off')
90
- ax[1].set_title("Target Image with Inferred Segmentation")
 
 
 
 
91
  # save it to buffer
92
  plt.tight_layout()
93
  buf = BytesIO()
@@ -106,12 +113,18 @@ def record_click(img, evt: gr.SelectData):
106
  click_coords.append([evt.index[0], evt.index[1]])
107
  return f"Clicked at: {click_coords}"
108
 
109
- def generate(reference_image, target_image):
 
110
  if not click_coords:
111
  return None, "Click on the reference image first!"
 
 
 
112
  ref_mask = segment_reference(reference_image, click_coords)
113
- tgt_mask = segment_target(target_image, reference_image, ref_mask)
114
- vis = visualize_segmentation(reference_image, ref_mask, target_image, tgt_mask)
 
 
115
  return vis, "Done!"
116
 
117
  with gr.Blocks() as demo:
@@ -119,7 +132,7 @@ with gr.Blocks() as demo:
119
 
120
  with gr.Row():
121
  reference_img = gr.Image(type="pil", label="Reference Image")
122
- target_img = gr.Image(type="pil", label="Target Image")
123
 
124
  click_info = gr.Textbox(label="Click Info")
125
  generate_btn = gr.Button("Generate")
 
63
 
64
  return masks
65
 
66
+ def segment_target(target_images, ref_image, ref_mask):
67
+ target_images = [np.array(target_image) for target_image in target_images]
68
  ref_image = np.array(ref_image)
69
+ state = sam_utils.load_masks(sam2_vid, target_images, ref_image, ref_mask)
70
+ out = sam_utils.propagate_masks(sam2_vid, state)[1:]
71
+ return [mask['segmentation'] for mask in out]
72
 
73
  def on_reference_upload(img):
74
  global click_coords
75
  click_coords = [] # clear the clicks
76
  return "Click Info: Cleared (new image uploaded)"
77
 
78
+ def visualize_segmentation(image, masks, target_images, target_masks):
79
  # Visualize the segmentation result
80
+ num_tgt = len(target_images)
81
+ fig, ax = plt.subplots(2, num_tgt, figsize=(6*num_tgt, 12))
82
+ if num_tgt == 1:
83
+ ax = np.expand_dims(ax, axis=1)
84
+ ax[0][0].imshow(image.convert("L"), cmap='gray')
85
  for i, mask in enumerate(masks):
86
+ sam_utils.show_mask(mask, ax[0][0], obj_id=i, alpha=0.75)
87
+ ax[0][0].axis('off')
88
+ ax[0][0].set_title("Reference Image with Expert Segmentation")
89
+ for i in range(1, num_tgt):
90
+ # set the rest to empty
91
+ ax[0][i].axis('off')
92
+ for i in range(num_tgt):
93
+ ax[1][i].imshow(target_images[i].convert("L"), cmap='gray')
94
+ for j, mask in enumerate(target_masks[i]):
95
+ sam_utils.show_mask(mask, ax[1][i], obj_id=j, alpha=0.75)
96
+ ax[1][i].axis('off')
97
+ ax[1][i].set_title("Target Image with Inferred Segmentation")
98
  # save it to buffer
99
  plt.tight_layout()
100
  buf = BytesIO()
 
113
  click_coords.append([evt.index[0], evt.index[1]])
114
  return f"Clicked at: {click_coords}"
115
 
116
+ def generate(reference_image, target_images):
117
+ global click_coords
118
  if not click_coords:
119
  return None, "Click on the reference image first!"
120
+
121
+ target_images = [Image.open(f.name).convert("RGB").resize((1024,1024)) for f in target_images]
122
+
123
  ref_mask = segment_reference(reference_image, click_coords)
124
+ tgt_masks = segment_target(target_images, reference_image, ref_mask)
125
+ vis = visualize_segmentation(reference_image, ref_mask, target_images, tgt_masks)
126
+ # clear the clicks
127
+ click_coords = []
128
  return vis, "Done!"
129
 
130
  with gr.Blocks() as demo:
 
132
 
133
  with gr.Row():
134
  reference_img = gr.Image(type="pil", label="Reference Image")
135
+ target_img = gr.File(file_types=["image"], file_count="multiple", label="Target Images")
136
 
137
  click_info = gr.Textbox(label="Click Info")
138
  generate_btn = gr.Button("Generate")