Luigi committed on
Commit
aca2590
·
1 Parent(s): b0ba5a9

Let user configure bbox_thr & nms_thr in GUI

Browse files
Files changed (2) hide show
  1. README.md +6 -1
  2. app.py +5 -3
README.md CHANGED
@@ -28,4 +28,9 @@ This HuggingFace Space runs the RTMO (Real-Time Multi-Person) 2D pose estimation
28
  We use the `rtmo` alias defined in MMPose’s model zoo. To override, upload your own checkpoint.
29
 
30
  ## Development
31
- If you need to update dependencies or change the model, modify `requirements.txt` and `app.py` accordingly.
 
 
 
 
 
 
28
  We use the `rtmo` alias defined in MMPose’s model zoo. To override, upload your own checkpoint.
29
 
30
  ## Development
31
+ If you need to update dependencies or change the model, modify `requirements.txt` and `app.py` accordingly.
32
+
33
+ ## Todos
34
+ 1. Let user configure bbox_thr & nms_thr in GUI
35
+ 2. Support video input
36
+ 3. Support models in ONNX format via rtmlib
app.py CHANGED
@@ -56,7 +56,7 @@ def detect_rtmo_variant(checkpoint_path: str) -> str:
56
 
57
  # ——— Gradio prediction function ———
58
  @spaces.GPU()
59
- def predict(image: Image.Image, checkpoint):
60
  # save upload to temp file
61
  inp_path = "/tmp/upload.jpg"
62
  image.save(inp_path)
@@ -70,8 +70,8 @@ def predict(image: Image.Image, checkpoint):
70
  inferencer = load_inferencer(checkpoint_path=ckpt_path, device=None)
71
  for result in inferencer(
72
  inputs=inp_path,
73
- bbox_thr=0.1,
74
- nms_thr=0.65,
75
  pose_based_nms=True,
76
  show=False,
77
  vis_out_dir=vis_dir,
@@ -89,6 +89,8 @@ demo = gr.Interface(
89
  inputs=[
90
  gr.Image(type="pil", label="Upload Image"),
91
  gr.File(file_types=['.pth'], label="Upload RTMO .pth Checkpoint (optional)")
 
 
92
  ],
93
  outputs=gr.Image(type="pil", label="Annotated Image"),
94
  title="RTMO Pose Demo",
 
56
 
57
  # ——— Gradio prediction function ———
58
  @spaces.GPU()
59
+ def predict(image: Image.Image, checkpoint, bbox_thr: float, nms_thr: float):
60
  # save upload to temp file
61
  inp_path = "/tmp/upload.jpg"
62
  image.save(inp_path)
 
70
  inferencer = load_inferencer(checkpoint_path=ckpt_path, device=None)
71
  for result in inferencer(
72
  inputs=inp_path,
73
+ bbox_thr=bbox_thr,
74
+ nms_thr=nms_thr,
75
  pose_based_nms=True,
76
  show=False,
77
  vis_out_dir=vis_dir,
 
89
  inputs=[
90
  gr.Image(type="pil", label="Upload Image"),
91
  gr.File(file_types=['.pth'], label="Upload RTMO .pth Checkpoint (optional)")
92
+ gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.1, label="Bounding Box Threshold (bbox_thr)"),
93
+ gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.65, label="NMS Threshold (nms_thr)"),
94
  ],
95
  outputs=gr.Image(type="pil", label="Annotated Image"),
96
  title="RTMO Pose Demo",