zhiweili commited on
Commit
440fd96
·
1 Parent(s): 65dfb4d

添加mediapipe切割头发

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. app.py +52 -0
  3. checkpoints/hair_segmenter.tflite +0 -0
  4. requirements.txt +2 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .vscode
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import mediapipe as mp
3
+ import numpy as np
4
+ from PIL import Image
5
+ from mediapipe.tasks import python
6
+ from mediapipe.tasks.python import vision
7
+ from scipy.ndimage import binary_dilation
8
+
9
+ BG_COLOR = np.array([0, 0, 0], dtype=np.uint8) # black
10
+ MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8) # white
11
+
12
+ MODEL_PATH = "checkpoints/hair_segmenter.tflite"
13
+ base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
14
+ options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
15
+ segmenter = vision.ImageSegmenter.create_from_options(options)
16
+
17
+ def segment(input_image):
18
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
19
+ segmentation_result = segmenter.segment(image)
20
+ category_mask = segmentation_result.category_mask
21
+
22
+ # Generate solid color images for showing the output segmentation mask.
23
+ image_data = image.numpy_view()
24
+ fg_image = np.zeros(image_data.shape, dtype=np.uint8)
25
+ fg_image[:] = MASK_COLOR
26
+ bg_image = np.zeros(image_data.shape, dtype=np.uint8)
27
+ bg_image[:] = BG_COLOR
28
+
29
+ dilated_mask = binary_dilation(category_mask.numpy_view(), iterations=4)
30
+ condition = np.stack((dilated_mask,) * 3, axis=-1) > 0.2
31
+
32
+ output_image = np.where(condition, fg_image, bg_image)
33
+ output_image = Image.fromarray(output_image)
34
+ return output_image
35
+
36
+ with gr.Blocks() as app:
37
+ with gr.Row():
38
+ with gr.Column():
39
+ input_image = gr.Image(type='pil', label='Upload image')
40
+ submit_btn = gr.Button(value='Submit', variant='primary')
41
+ with gr.Column():
42
+ output_image = gr.Image(type='pil', label='Image Output')
43
+
44
+ submit_btn.click(
45
+ fn=segment,
46
+ inputs=[
47
+ input_image,
48
+ ],
49
+ outputs=[output_image]
50
+ )
51
+
52
+ app.launch(debug=False, show_error=True)
checkpoints/hair_segmenter.tflite ADDED
Binary file (782 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ mediapipe
2
+ gradio