Luigi commited on
Commit
c663473
Β·
1 Parent(s): 4067901

add 3 examples

Browse files
Files changed (1) hide show
  1. app.py +21 -15
app.py CHANGED
@@ -57,7 +57,6 @@ VARIANT_PREFIX = {
57
 
58
  # β€”β€”β€” Helper: download checkpoint if remote β€”β€”β€”
59
  def get_checkpoint(path_or_key: str) -> str:
60
- # If it's a key in REMOTE_CHECKPOINTS, download
61
  if path_or_key in REMOTE_CHECKPOINTS:
62
  url = REMOTE_CHECKPOINTS[path_or_key]
63
  local_path = f"/tmp/{path_or_key}.pth"
@@ -67,7 +66,6 @@ def get_checkpoint(path_or_key: str) -> str:
67
  for chunk in r.iter_content(1024):
68
  f.write(chunk)
69
  return local_path
70
- # Otherwise assume it's a local file path
71
  return path_or_key
72
 
73
  # β€”β€”β€” Detect variant alias from checkpoint β€”β€”β€”
@@ -88,34 +86,26 @@ def load_inferencer(checkpoint_path=None, device=None):
88
  kwargs['pose2d'] = variant
89
  kwargs['pose2d_weights'] = checkpoint_path
90
  else:
91
- # default to rtmo-s
92
  kwargs['pose2d'] = 'rtmo'
93
  return MMPoseInferencer(**kwargs)
94
 
95
- # β€”β€”β€” Gradio prediction function β€”β€”β€”
96
  @spaces.GPU()
97
  def predict(image: Image.Image,
98
  remote_ckpt: str,
99
  upload_ckpt,
100
  bbox_thr: float,
101
  nms_thr: float):
102
- # save input image
103
  inp_path = "/tmp/upload.jpg"
104
  image.save(inp_path)
105
-
106
- # choose checkpoint: upload overrides remote
107
  if upload_ckpt:
108
  ckpt_path = upload_ckpt.name
109
  active = os.path.basename(ckpt_path)
110
  else:
111
  ckpt_path = get_checkpoint(remote_ckpt)
112
  active = remote_ckpt
113
-
114
- # prepare output dir
115
  vis_dir = "/tmp/vis"
116
  os.makedirs(vis_dir, exist_ok=True)
117
-
118
- # run inference
119
  inferencer = load_inferencer(checkpoint_path=ckpt_path, device=None)
120
  for result in inferencer(
121
  inputs=inp_path,
@@ -126,14 +116,11 @@ def predict(image: Image.Image,
126
  vis_out_dir=vis_dir,
127
  ):
128
  pass
129
-
130
- # load result image
131
  out_files = sorted(os.listdir(vis_dir))
132
  vis_img = Image.open(os.path.join(vis_dir, out_files[0])) if out_files else None
133
  return vis_img, active
134
 
135
- # Build Gradio UI with Blocks for improved layout
136
-
137
  def main():
138
  with gr.Blocks() as demo:
139
  gr.Markdown("## RTMO Pose Demo")
@@ -153,6 +140,25 @@ def main():
153
  output_img = gr.Image(type="pil", label="Annotated Image",
154
  elem_id="output_image", interactive=False)
155
  active_tb = gr.Textbox(label="Active Checkpoint", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  run_btn.click(predict,
157
  inputs=[img_input, remote_dd, upload_ckpt, bbox_thr, nms_thr],
158
  outputs=[output_img, active_tb])
 
57
 
58
  # β€”β€”β€” Helper: download checkpoint if remote β€”β€”β€”
59
  def get_checkpoint(path_or_key: str) -> str:
 
60
  if path_or_key in REMOTE_CHECKPOINTS:
61
  url = REMOTE_CHECKPOINTS[path_or_key]
62
  local_path = f"/tmp/{path_or_key}.pth"
 
66
  for chunk in r.iter_content(1024):
67
  f.write(chunk)
68
  return local_path
 
69
  return path_or_key
70
 
71
  # β€”β€”β€” Detect variant alias from checkpoint β€”β€”β€”
 
86
  kwargs['pose2d'] = variant
87
  kwargs['pose2d_weights'] = checkpoint_path
88
  else:
 
89
  kwargs['pose2d'] = 'rtmo'
90
  return MMPoseInferencer(**kwargs)
91
 
92
+ # —─── Prediction function ────
93
  @spaces.GPU()
94
  def predict(image: Image.Image,
95
  remote_ckpt: str,
96
  upload_ckpt,
97
  bbox_thr: float,
98
  nms_thr: float):
 
99
  inp_path = "/tmp/upload.jpg"
100
  image.save(inp_path)
 
 
101
  if upload_ckpt:
102
  ckpt_path = upload_ckpt.name
103
  active = os.path.basename(ckpt_path)
104
  else:
105
  ckpt_path = get_checkpoint(remote_ckpt)
106
  active = remote_ckpt
 
 
107
  vis_dir = "/tmp/vis"
108
  os.makedirs(vis_dir, exist_ok=True)
 
 
109
  inferencer = load_inferencer(checkpoint_path=ckpt_path, device=None)
110
  for result in inferencer(
111
  inputs=inp_path,
 
116
  vis_out_dir=vis_dir,
117
  ):
118
  pass
 
 
119
  out_files = sorted(os.listdir(vis_dir))
120
  vis_img = Image.open(os.path.join(vis_dir, out_files[0])) if out_files else None
121
  return vis_img, active
122
 
123
+ # —─── Gradio UI ────
 
124
  def main():
125
  with gr.Blocks() as demo:
126
  gr.Markdown("## RTMO Pose Demo")
 
140
  output_img = gr.Image(type="pil", label="Annotated Image",
141
  elem_id="output_image", interactive=False)
142
  active_tb = gr.Textbox(label="Active Checkpoint", interactive=False)
143
+
144
+ # Examples for quick testing
145
+ gr.Examples(
146
+ examples=[
147
+ ["https://images.pexels.com/photos/1858175/pexels-photo-1858175.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=614",
148
+ "rtmo-s_coco_retrainable", None, 0.1, 0.65],
149
+ ["https://images.pexels.com/photos/3779706/pexels-photo-3779706.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=614",
150
+ "rtmo-t_8xb32-600e_body7", None, 0.1, 0.65],
151
+ ["https://images.pexels.com/photos/220453/pexels-photo-220453.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=614",
152
+ "rtmo-s_8xb32-600e_coco", None, 0.1, 0.65],
153
+ ],
154
+ inputs=[img_input, remote_dd, upload_ckpt, bbox_thr, nms_thr],
155
+ outputs=[output_img, active_tb],
156
+ fn=predict,
157
+ cache_examples=False,
158
+ label="Examples",
159
+ examples_per_page=3
160
+ )
161
+
162
  run_btn.click(predict,
163
  inputs=[img_input, remote_dd, upload_ckpt, bbox_thr, nms_thr],
164
  outputs=[output_img, active_tb])