Spaces:
Running
on
Zero
Running
on
Zero
support all variants with added vairant detection
Browse files
app.py
CHANGED
@@ -3,6 +3,7 @@ import spaces
|
|
3 |
import os, sys, importlib.util, re
|
4 |
import gradio as gr
|
5 |
from PIL import Image
|
|
|
6 |
|
7 |
# βββ Monkey-patch mmdet to remove its mmcv-version assertion βββ
|
8 |
spec = importlib.util.find_spec('mmdet')
|
@@ -23,8 +24,33 @@ def load_inferencer(checkpoint_path=None, device=None):
|
|
23 |
kwargs = {'pose2d': 'rtmo', 'scope': 'mmpose', 'device': device, 'det_cat_ids': [0]}
|
24 |
if checkpoint_path:
|
25 |
kwargs['pose2d_weights'] = checkpoint_path
|
|
|
|
|
|
|
26 |
return MMPoseInferencer(**kwargs)
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
# βββ Gradio prediction function βββ
|
29 |
@spaces.GPU()
|
30 |
def predict(image: Image.Image, checkpoint):
|
|
|
3 |
import os, sys, importlib.util, re
|
4 |
import gradio as gr
|
5 |
from PIL import Image
|
6 |
+
import torch
|
7 |
|
8 |
# βββ Monkey-patch mmdet to remove its mmcv-version assertion βββ
|
9 |
spec = importlib.util.find_spec('mmdet')
|
|
|
24 |
kwargs = {'pose2d': 'rtmo', 'scope': 'mmpose', 'device': device, 'det_cat_ids': [0]}
|
25 |
if checkpoint_path:
|
26 |
kwargs['pose2d_weights'] = checkpoint_path
|
27 |
+
# detect model variant
|
28 |
+
variant = detect_rtmo_variant(checkpoint_path)
|
29 |
+
kwargs['pose2d'] = variant
|
30 |
return MMPoseInferencer(**kwargs)
|
31 |
|
32 |
+
def detect_rtmo_variant(checkpoint_path: str) -> str:
|
33 |
+
"""
|
34 |
+
Inspect an RTMO .pth checkpoint and return its variant alias:
|
35 |
+
one of 'rtmo-l', 'rtmo-m', 'rtmo-s', 'rtmo-t', or 'unknown'.
|
36 |
+
"""
|
37 |
+
ckpt = torch.load(checkpoint_path, map_location='cpu')
|
38 |
+
state_dict = ckpt.get('state_dict', ckpt)
|
39 |
+
|
40 |
+
key = 'backbone.stem.conv.conv.weight'
|
41 |
+
if key not in state_dict:
|
42 |
+
raise KeyError(f"Cannot find '{key}' in checkpoint.")
|
43 |
+
|
44 |
+
out_ch = state_dict[key].shape[0]
|
45 |
+
|
46 |
+
mapping = {
|
47 |
+
24: "rtmo-t_8xb32-600e_body7-416x416",
|
48 |
+
32: "rtmo-s_8xb32-600e_body7-640x640",
|
49 |
+
48: "rtmo-m_16xb16-600e_body7-640x640",
|
50 |
+
64: "rtmo-l_16xb16-600e_body7-640x640",
|
51 |
+
}
|
52 |
+
return mapping.get(out_ch, f'unknown (stem out_channels={out_ch})')
|
53 |
+
|
54 |
# βββ Gradio prediction function βββ
|
55 |
@spaces.GPU()
|
56 |
def predict(image: Image.Image, checkpoint):
|