zhiweili commited on
Commit
5540863
·
1 Parent(s): 440fd96

change to selfie_multiclass

Browse files
app.py CHANGED
@@ -9,15 +9,17 @@ from scipy.ndimage import binary_dilation
9
  BG_COLOR = np.array([0, 0, 0], dtype=np.uint8) # black
10
  MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8) # white
11
 
12
- MODEL_PATH = "checkpoints/hair_segmenter.tflite"
13
  base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
14
  options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
15
  segmenter = vision.ImageSegmenter.create_from_options(options)
 
16
 
17
- def segment(input_image):
18
  image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
19
  segmentation_result = segmenter.segment(image)
20
  category_mask = segmentation_result.category_mask
 
21
 
22
  # Generate solid color images for showing the output segmentation mask.
23
  image_data = image.numpy_view()
@@ -37,6 +39,7 @@ with gr.Blocks() as app:
37
  with gr.Row():
38
  with gr.Column():
39
  input_image = gr.Image(type='pil', label='Upload image')
 
40
  submit_btn = gr.Button(value='Submit', variant='primary')
41
  with gr.Column():
42
  output_image = gr.Image(type='pil', label='Image Output')
@@ -45,6 +48,7 @@ with gr.Blocks() as app:
45
  fn=segment,
46
  inputs=[
47
  input_image,
 
48
  ],
49
  outputs=[output_image]
50
  )
 
9
  BG_COLOR = np.array([0, 0, 0], dtype=np.uint8) # black
10
  MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8) # white
11
 
12
+ MODEL_PATH = "checkpoints/selfie_multiclass_256x256.tflite"
13
  base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
14
  options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
15
  segmenter = vision.ImageSegmenter.create_from_options(options)
16
+ MASK_CATEGORY = segmenter.labels
17
 
18
+ def segment(input_image, category):
19
  image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
20
  segmentation_result = segmenter.segment(image)
21
  category_mask = segmentation_result.category_mask
22
+ category_mask = category_mask == MASK_CATEGORY.index(category)
23
 
24
  # Generate solid color images for showing the output segmentation mask.
25
  image_data = image.numpy_view()
 
39
  with gr.Row():
40
  with gr.Column():
41
  input_image = gr.Image(type='pil', label='Upload image')
42
+ category = gr.Dropdown(label='Category', choices=MASK_CATEGORY, value=MASK_CATEGORY[1])
43
  submit_btn = gr.Button(value='Submit', variant='primary')
44
  with gr.Column():
45
  output_image = gr.Image(type='pil', label='Image Output')
 
48
  fn=segment,
49
  inputs=[
50
  input_image,
51
+ category,
52
  ],
53
  outputs=[output_image]
54
  )
checkpoints/selfie_multiclass_256x256.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6748b1253a99067ef71f7e26ca71096cd449baefa8f101900ea23016507e0e0
3
+ size 16371837