Sm0kyWu commited on
Commit
f9aef86
Β·
verified Β·
1 Parent(s): bfb6f69

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -63,15 +63,16 @@ def run_sam(image, predictor, selected_points):
63
  # input_points = np.array([[210, 300]])
64
  # input_labels = np.array([1])
65
  masks, _, _ = predictor.predict(
66
- point_coords=input_points,
67
- point_labels=input_labels,
68
  multimask_output=False, # 单对豑输出
69
  )
70
- # print(masks.shape, np.unique(masks))
71
  best_mask = masks[0]
72
- # best_mask = masks[np.argmax(scores)]
73
- # print(np.unique(best_mask), best_mask.shape)
74
- # print(type(best_mask), best_mask.dtype)
 
 
75
  visible_mask = 255 * best_mask.astype(np.uint8)
76
  return visible_mask
77
 
@@ -396,7 +397,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
396
 
397
  with gr.Row():
398
  with gr.Column():
399
- input_image = gr.Image(type="numpy", label='Input Occlusion Image', height=300)
400
  with gr.Row():
401
  x_input = gr.Number(label="X Coordinate", value=0)
402
  y_input = gr.Number(label="Y Coordinate", value=0)
@@ -443,7 +444,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
443
  input_image.upload(
444
  reset_image,
445
  [predictor, input_image],
446
- [predictor, visible_mask, occlusion_mask, original_image]
447
  )
448
  see_button.click(
449
  see_point,
 
63
  # input_points = np.array([[210, 300]])
64
  # input_labels = np.array([1])
65
  masks, _, _ = predictor.predict(
66
+ point_coords=np.array(input_points),
67
+ point_labels=np.array(input_labels),
68
  multimask_output=False, # 单对豑输出
69
  )
 
70
  best_mask = masks[0]
71
+ # dilate
72
+ if len(input_points) > 1:
73
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
74
+ best_mask = cv2.dilate(best_mask, kernel, iterations=1)
75
+ best_mask = cv2.erode(best_mask, kernel, iterations=1)
76
  visible_mask = 255 * best_mask.astype(np.uint8)
77
  return visible_mask
78
 
 
397
 
398
  with gr.Row():
399
  with gr.Column():
400
+ input_image = gr.Image(type="numpy", label='Input Occlusion Image', interactive=False, height=300)
401
  with gr.Row():
402
  x_input = gr.Number(label="X Coordinate", value=0)
403
  y_input = gr.Number(label="Y Coordinate", value=0)
 
444
  input_image.upload(
445
  reset_image,
446
  [predictor, input_image],
447
+ [predictor, original_image]
448
  )
449
  see_button.click(
450
  see_point,