LPX55 commited on
Commit
0d5a836
·
1 Parent(s): 5a9b472
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -65,10 +65,8 @@ def load_default_pipeline():
65
  @spaces.GPU()
66
  def predict_masks(image, points):
67
  """Predict a single mask from the image based on selected points."""
68
-
69
  if not points:
70
  return image # Return the original image if no points are selected
71
-
72
  PREDICTOR = SAM2ImagePredictor.from_pretrained(SAM_MODEL, device=DEVICE)
73
 
74
  image_np = np.array(image)
@@ -535,6 +533,8 @@ with gr.Blocks(css=css, fill_height=True) as demo:
535
  )
536
  with gr.Row():
537
  with gr.Column():
 
 
538
  upload_image_input = ImagePrompter(show_label=False)
539
  with gr.Column():
540
  image_output = gr.Image(label="Segmented Image", type="pil", height=400)
@@ -562,6 +562,11 @@ with gr.Blocks(css=css, fill_height=True) as demo:
562
  inputs=None,
563
  outputs=load_default_message,
564
  )
 
 
 
 
 
565
  target_ratio.change(
566
  fn=preload_presets,
567
  inputs=[target_ratio, width_slider, height_slider],
 
65
  @spaces.GPU()
66
  def predict_masks(image, points):
67
  """Predict a single mask from the image based on selected points."""
 
68
  if not points:
69
  return image # Return the original image if no points are selected
 
70
  PREDICTOR = SAM2ImagePredictor.from_pretrained(SAM_MODEL, device=DEVICE)
71
 
72
  image_np = np.array(image)
 
533
  )
534
  with gr.Row():
535
  with gr.Column():
536
+ image_input = gr.State()
537
+ # Input: ImagePrompter for uploaded image
538
  upload_image_input = ImagePrompter(show_label=False)
539
  with gr.Column():
540
  image_output = gr.Image(label="Segmented Image", type="pil", height=400)
 
562
  inputs=None,
563
  outputs=load_default_message,
564
  )
565
+
566
+ upload_image_input.change(
567
+ fn=lambda img: img, inputs=upload_image_input, outputs=image_input
568
+ )
569
+
570
  target_ratio.change(
571
  fn=preload_presets,
572
  inputs=[target_ratio, width_slider, height_slider],