Luigi commited on
Commit
f5a0539
Β·
1 Parent(s): cea6d9d

add gradio ui

Browse files
Files changed (2) hide show
  1. README.md +18 -1
  2. app.py +73 -0
README.md CHANGED
@@ -11,4 +11,21 @@ license: apache-2.0
11
  short_description: RTMO PyTorch Checkpoint Tester
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  short_description: RTMO PyTorch Checkpoint Tester
12
  ---
13
 
14
+ # RTMO PyTorch Checkpoint Tester
15
+
16
+ This HuggingFace Space runs the RTMO (Real-Time Multi-Person) 2D pose estimation model from OpenMMLab.
17
+
18
+ ## Usage
19
+ 1. Upload an image via the Gradio UI.
20
+ 2. (Optional) Provide a path or URL to your own RTMO PyTorch `.pth` checkpoint. If left blank, the default pretrained weights will be used.
21
+ 3. Click **Submit**. The annotated image with keypoints will be displayed.
22
+
23
+ ## Files
24
+ - **app.py**: Gradio application script that loads the RTMO model, runs inference, and displays results.
25
+ - **requirements.txt**: Python dependencies, including the patched MMCV build and MMPose.
26
+
27
+ ## Model
28
+ We use the `rtmo` alias defined in MMPose’s model zoo. To override, upload your own checkpoint.
29
+
30
+ ## Development
31
+ If you need to update dependencies or change the model, modify `requirements.txt` and `app.py` accordingly.
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import spaces
3
+ import os, sys, importlib.util, re
4
+ import gradio as gr
5
+ from PIL import Image
6
+
7
+ # β€”β€”β€” Monkey-patch mmdet to remove its mmcv-version assertion β€”β€”β€”
8
+ spec = importlib.util.find_spec('mmdet')
9
+ if spec and spec.origin:
10
+ src = open(spec.origin, encoding='utf-8').read()
11
+ # strip out the mmcv_minimum_version…assert… block up to __all__
12
+ patched = re.sub(r'(?ms)^[ \t]*mmcv_minimum_version.*?^__all__', '__all__', src)
13
+ m = importlib.util.module_from_spec(spec)
14
+ m.__loader__ = spec.loader
15
+ m.__file__ = spec.origin
16
+ m.__path__ = spec.submodule_search_locations
17
+ sys.modules['mmdet'] = m
18
+ exec(compile(patched, spec.origin, 'exec'), m.__dict__)
19
+
20
+ from mmpose.apis.inferencers import MMPoseInferencer
21
+
22
+ # β€”β€”β€” Initialize inferencer with default RTMO 2D model β€”β€”β€”
23
+ def load_inferencer(checkpoint_path=None, device=None):
24
+ kwargs = {'pose2d': 'rtmo', 'scope': 'mmpose', 'device': device, 'det_cat_ids': [0]}
25
+ if checkpoint_path:
26
+ kwargs['pose2d_weights'] = checkpoint_path
27
+ return MMPoseInferencer(**kwargs)
28
+
29
+ # β€”β€”β€” Gradio prediction function β€”β€”β€”
30
+ @spaces.GPU()
31
+ def predict(image: Image.Image, checkpoint: str = None):
32
+ # save upload to temp file
33
+ inp_path = "/tmp/upload.jpg"
34
+ image.save(inp_path)
35
+
36
+ vis_dir = "/tmp/vis"
37
+ os.makedirs(vis_dir, exist_ok=True)
38
+
39
+ inferencer = load_inferencer(checkpoint_path=checkpoint, device=None)
40
+ # run inference & visualization
41
+ for result in inferencer(
42
+ inputs=inp_path,
43
+ bbox_thr=0.1,
44
+ nms_thr=0.65,
45
+ pose_based_nms=True,
46
+ show=False,
47
+ vis_out_dir=vis_dir,
48
+ ):
49
+ pass
50
+
51
+ # return the first visualization
52
+ out_files = sorted(os.listdir(vis_dir))
53
+ if out_files:
54
+ return Image.open(os.path.join(vis_dir, out_files[0]))
55
+ return None
56
+
57
+ # β€”β€”β€” Build Gradio Interface β€”β€”β€”
58
+ demo = gr.Interface(
59
+ fn=predict,
60
+ inputs=[
61
+ gr.inputs.Image(type="pil", label="Upload Image"),
62
+ gr.inputs.Text(label="RTMO PyTorch Checkpoint Path (optional)")
63
+ ],
64
+ outputs=gr.outputs.Image(type="pil", label="Annotated Image"),
65
+ title="RTMO Pose Demo",
66
+ description="Upload an image, optionally supply a RTMO .pth checkpoint, and see 2D pose annotation.",
67
+ )
68
+
69
+ def main():
70
+ demo.launch()
71
+
72
+ if __name__ == "__main__":
73
+ main()