Luigi commited on
Commit
86fc76b
Β·
1 Parent(s): 6452205

support all variants with added vairant detection

Browse files
Files changed (1) hide show
  1. app.py +26 -0
app.py CHANGED
@@ -3,6 +3,7 @@ import spaces
3
  import os, sys, importlib.util, re
4
  import gradio as gr
5
  from PIL import Image
 
6
 
7
  # β€”β€”β€” Monkey-patch mmdet to remove its mmcv-version assertion β€”β€”β€”
8
  spec = importlib.util.find_spec('mmdet')
@@ -23,8 +24,33 @@ def load_inferencer(checkpoint_path=None, device=None):
23
  kwargs = {'pose2d': 'rtmo', 'scope': 'mmpose', 'device': device, 'det_cat_ids': [0]}
24
  if checkpoint_path:
25
  kwargs['pose2d_weights'] = checkpoint_path
 
 
 
26
  return MMPoseInferencer(**kwargs)
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # β€”β€”β€” Gradio prediction function β€”β€”β€”
29
  @spaces.GPU()
30
  def predict(image: Image.Image, checkpoint):
 
3
  import os, sys, importlib.util, re
4
  import gradio as gr
5
  from PIL import Image
6
+ import torch
7
 
8
  # β€”β€”β€” Monkey-patch mmdet to remove its mmcv-version assertion β€”β€”β€”
9
  spec = importlib.util.find_spec('mmdet')
 
24
  kwargs = {'pose2d': 'rtmo', 'scope': 'mmpose', 'device': device, 'det_cat_ids': [0]}
25
  if checkpoint_path:
26
  kwargs['pose2d_weights'] = checkpoint_path
27
+ # detect model variant
28
+ variant = detect_rtmo_variant(checkpoint_path)
29
+ kwargs['pose2d'] = variant
30
  return MMPoseInferencer(**kwargs)
31
 
32
+ def detect_rtmo_variant(checkpoint_path: str) -> str:
33
+ """
34
+ Inspect an RTMO .pth checkpoint and return its variant alias:
35
+ one of 'rtmo-l', 'rtmo-m', 'rtmo-s', 'rtmo-t', or 'unknown'.
36
+ """
37
+ ckpt = torch.load(checkpoint_path, map_location='cpu')
38
+ state_dict = ckpt.get('state_dict', ckpt)
39
+
40
+ key = 'backbone.stem.conv.conv.weight'
41
+ if key not in state_dict:
42
+ raise KeyError(f"Cannot find '{key}' in checkpoint.")
43
+
44
+ out_ch = state_dict[key].shape[0]
45
+
46
+ mapping = {
47
+ 24: "rtmo-t_8xb32-600e_body7-416x416",
48
+ 32: "rtmo-s_8xb32-600e_body7-640x640",
49
+ 48: "rtmo-m_16xb16-600e_body7-640x640",
50
+ 64: "rtmo-l_16xb16-600e_body7-640x640",
51
+ }
52
+ return mapping.get(out_ch, f'unknown (stem out_channels={out_ch})')
53
+
54
  # β€”β€”β€” Gradio prediction function β€”β€”β€”
55
  @spaces.GPU()
56
  def predict(image: Image.Image, checkpoint):