kristyc committed
Commit 18a5bd9 · 1 Parent(s): ff3b552

Add more granular model configuration options

Files changed (1): app.py +79 -9
app.py CHANGED
@@ -1,3 +1,4 @@
+from os import stat
 import gradio as gr
 from matplotlib.pyplot import draw
 import mediapipe as mp
@@ -5,12 +6,12 @@ import numpy as np
 import tempfile
 import mediapy as media
 import log_utils
+from functools import lru_cache
 
 logger = log_utils.get_logger()
 
 mp_hands = mp.solutions.hands
 mp_hands_connections = mp.solutions.hands_connections
-hands = mp_hands.Hands()
 mp_draw = mp.solutions.drawing_utils
 
 connections = {
@@ -23,27 +24,57 @@ connections = {
     'HAND_PINKY_FINGER_CONNECTIONS': mp_hands_connections.HAND_PINKY_FINGER_CONNECTIONS,
 }
 
-def draw_landmarks(img, selected_connections, draw_background):
-    results = hands.process(img)
+@lru_cache(maxsize=10)
+def get_model(static_image_mode, max_num_hands, model_complexity, min_detection_conf, min_tracking_conf):
+    return mp_hands.Hands(
+        static_image_mode=static_image_mode,
+        max_num_hands=max_num_hands,
+        model_complexity=model_complexity,
+        min_detection_confidence=min_detection_conf,
+        min_tracking_confidence=min_tracking_conf,
+    )
+
+def draw_landmarks(model, img, selected_connections, draw_background):
+    results = model.process(img)
     output_img = img if draw_background else np.zeros_like(img)
     if results.multi_hand_landmarks:
         for hand_landmarks in results.multi_hand_landmarks:
             mp_draw.draw_landmarks(output_img, hand_landmarks, connections[selected_connections])
     return output_img
 
-def process_image(img, selected_connections, draw_background):
+def process_image(
+    img,
+    static_image_mode,
+    max_num_hands,
+    model_complexity,
+    min_detection_conf,
+    min_tracking_conf,
+    selected_connections,
+    draw_background,
+):
     logger.info(f"Processing image with connections: {selected_connections}, draw background: {draw_background}")
-    return draw_landmarks(img, selected_connections, draw_background)
+    model = get_model(static_image_mode, max_num_hands, model_complexity, min_detection_conf, min_tracking_conf)
+    return draw_landmarks(model, img, selected_connections, draw_background)
 
-def process_video(video_path, selected_connections, draw_background):
+def process_video(
+    video_path,
+    static_image_mode,
+    max_num_hands,
+    model_complexity,
+    min_detection_conf,
+    min_tracking_conf,
+    selected_connections,
+    draw_background,
+):
     logger.info(f"Processing video with connections: {selected_connections}, draw background: {draw_background}")
+    model = get_model(static_image_mode, max_num_hands, model_complexity, min_detection_conf, min_tracking_conf)
     with tempfile.NamedTemporaryFile() as f:
         out_path = f"{f.name}.{video_path.split('.')[-1]}"
         with media.VideoReader(video_path) as r:
             with media.VideoWriter(
                 out_path, shape=r.shape, fps=r.fps, bps=r.bps) as w:
                 for image in r:
-                    w.add_image(draw_landmarks(image, selected_connections, draw_background))
+                    w.add_image(draw_landmarks(model, image, selected_connections, draw_background))
     return out_path
 
 
@@ -56,7 +87,31 @@ with demo:
     This is a demo of hand and finger tracking using [Google's MediaPipe](https://google.github.io/mediapipe/solutions/hands.html).
     """)
 
-    with gr.Column():
+    with gr.Column():
+        gr.Markdown("""
+        ## Step 1: Configure the model
+        """)
+        with gr.Column():
+            static_image_mode = gr.Checkbox(label="Static image mode", value=False)
+            gr.Textbox(show_label=False, value="If unchecked, the solution treats the input images as a video stream. It will try to detect hands in the first input images, and upon a successful detection further localizes the hand landmarks. In subsequent images, once all max_num_hands hands are detected and the corresponding hand landmarks are localized, it simply tracks those landmarks without invoking another detection until it loses track of any of the hands. This reduces latency and is ideal for processing video frames. If checked, hand detection runs on every input image, ideal for processing a batch of static, possibly unrelated, images.")
+
+        max_num_hands = gr.Slider(label="Max number of hands to detect", value=1, minimum=1, maximum=10, step=1)
+
+        with gr.Column():
+            model_complexity = gr.Radio(label="Model complexity", choices=[0, 1], value=1)
+            gr.Textbox(show_label=False, value="Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as inference latency generally go up with the model complexity.")
+
+        with gr.Column():
+            min_detection_conf = gr.Slider(label="Min detection confidence", value=0.5, minimum=0.0, maximum=1.0, step=0.1)
+            gr.Textbox(show_label=False, value="Minimum confidence value ([0.0, 1.0]) from the hand detection model for the detection to be considered successful.")
+
+        with gr.Column():
+            min_tracking_conf = gr.Slider(label="Min tracking confidence", value=0.5, minimum=0.0, maximum=1.0, step=0.1)
+            gr.Textbox(show_label=False, value="Minimum confidence value ([0.0, 1.0]) from the landmark-tracking model for the hand landmarks to be considered tracked successfully, or otherwise hand detection will be invoked automatically on the next input image. Setting it to a higher value can increase robustness of the solution, at the expense of a higher latency. Ignored if static_image_mode is true, where hand detection simply runs on every image.")
+
+        gr.Markdown("""
+        ## Step 2: Set processing parameters
+        """)
         draw_background = gr.Checkbox(value=True, label="Draw background?")
         connection_keys = list(connections.keys())
         selected_connections = gr.Dropdown(
@@ -64,6 +119,10 @@ with demo:
             choices=connection_keys,
            value=connection_keys[0],
        )
+
+    gr.Markdown("""
+    ## Step 3: Select an image or video
+    """)
    with gr.Tabs():
        with gr.TabItem(label="Upload an image"):
            uploaded_image = gr.Image(type="numpy")
@@ -78,12 +137,23 @@ with demo:
            uploaded_video = gr.Video(format="mp4")
            submit_uploaded_video = gr.Button(value="Process Video")
 
+    gr.Markdown("""
+    ## Step 4: View results
+    """)
    with gr.Column():
        processed_video = gr.Video()
        processed_image = gr.Image()
 
    gr.Markdown('<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=kristyc.mediapipe-hands" />')
-    setting_inputs = [selected_connections, draw_background]
+    setting_inputs = [
+        static_image_mode,
+        max_num_hands,
+        model_complexity,
+        min_detection_conf,
+        min_tracking_conf,
+        selected_connections,
+        draw_background,
+    ]
    submit_uploaded_image.click(fn=process_image, inputs=[uploaded_image, *setting_inputs], outputs=[processed_image])
    submit_camera_picture.click(fn=process_image, inputs=[camera_picture, *setting_inputs], outputs=[processed_image])
    submit_recorded_video.click(fn=process_video, inputs=[recorded_video, *setting_inputs], outputs=[processed_video])
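
To illustrate the caching behaviour the new get_model helper relies on, here is a minimal sketch (not part of the commit, assuming the get_model definition above): functools.lru_cache keys on the exact argument values, so submitting the same settings twice reuses one mp_hands.Hands instance, while changing any setting builds and caches a separate model, up to 10 distinct configurations.

# Hypothetical usage check, not in app.py.
m1 = get_model(False, 1, 1, 0.5, 0.5)
m2 = get_model(False, 1, 1, 0.5, 0.5)
assert m1 is m2        # same settings -> cache hit, the Hands model is reused
m3 = get_model(True, 1, 1, 0.5, 0.5)
assert m3 is not m1    # any changed setting -> a new model is built and cached
print(get_model.cache_info())  # hits/misses and current cache size (maxsize=10)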