Spaces:
Running
on
Zero
Running
on
Zero
File size: 9,311 Bytes
f5a0539 db758db f5a0539 86fc76b 341e0f2 db758db 36190e5 341e0f2 36190e5 f5a0539 341e0f2 73f03cf 341e0f2 86fc76b 341e0f2 86fc76b 341e0f2 4067901 341e0f2 86fc76b c663473 f5a0539 341e0f2 db758db 341e0f2 db758db 341e0f2 db758db f5a0539 db758db f5a0539 db758db 6452205 db758db f5a0539 aca2590 f5a0539 db758db f5a0539 db758db f5a0539 c663473 f5a0539 4067901 db758db 4067901 db758db 4067901 db758db c663473 db758db a32a9b3 c663473 db758db c663473 a32a9b3 c663473 db758db f5a0539 db758db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
#!/usr/bin/env python3
import spaces
import os
import sys
import importlib.util
import re
import gradio as gr
from PIL import Image
import torch
import requests # for downloading remote checkpoints
import shutil
# CUDA info: best-effort startup diagnostics, must never abort the app.
try:
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"CUDA version: {torch.version.cuda}")
    # Raises when no CUDA device can be queried; handled below.
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
except Exception:
    # `except Exception` instead of a bare `except:` so that
    # KeyboardInterrupt and SystemExit still propagate.
    print('CUDA is not available !')
# --- Monkey-patch mmdet to remove its mmcv-version assertion ---
# The package's __init__ source is loaded by hand, the version-check section is
# cut out, and the result is executed as the `mmdet` module.
spec = importlib.util.find_spec('mmdet')
if spec and spec.origin:
    # Use a context manager so the source file handle is closed right away
    # instead of being left to the garbage collector.
    with open(spec.origin, encoding='utf-8') as src_file:
        src = src_file.read()
    # Remove everything from the first line starting with
    # `mmcv_minimum_version` up to the next line starting with `__all__`,
    # keeping the `__all__` token itself.
    patched = re.sub(r'(?ms)^[ \t]*mmcv_minimum_version.*?^__all__', '__all__', src)
    m = importlib.util.module_from_spec(spec)
    m.__loader__ = spec.loader
    m.__file__ = spec.origin
    m.__path__ = spec.submodule_search_locations
    # Register before executing so imports triggered while the patched source
    # runs resolve `mmdet` to this module object.
    sys.modules['mmdet'] = m
    exec(compile(patched, spec.origin, 'exec'), m.__dict__)
from mmpose.apis.inferencers import MMPoseInferencer
# Remote checkpoints: dropdown key -> download URL.
# URLs are assembled from a per-host base plus the checkpoint file name.
_OPENMMLAB_RTMO_BASE = "https://download.openmmlab.com/mmpose/v1/projects/rtmo/"
_HF_RETRAINABLE_BASE = "https://huggingface.co/Luigi/Retrainable-RTMO-s/resolve/main/"

_OPENMMLAB_RTMO_FILES = {
    # COCO-trained
    "rtmo-s_8xb32-600e_coco": "rtmo-s_8xb32-600e_coco-640x640-8db55a59_20231211.pth",
    "rtmo-m_16xb16-600e_coco": "rtmo-m_16xb16-600e_coco-640x640-6f4e0306_20231211.pth",
    "rtmo-l_16xb16-600e_coco": "rtmo-l_16xb16-600e_coco-640x640-516a421f_20231211.pth",
    # BODY7-trained
    "rtmo-t_8xb32-600e_body7": "rtmo-t_8xb32-600e_body7-416x416-f48f75cb_20231219.pth",
    "rtmo-s_8xb32-600e_body7": "rtmo-s_8xb32-600e_body7-640x640-dac2bf74_20231211.pth",
    "rtmo-m_16xb16-600e_body7": "rtmo-m_16xb16-600e_body7-640x640-39e78cc4_20231211.pth",
    "rtmo-l_16xb16-600e_body7": "rtmo-l_16xb16-600e_body7-640x640-b37118ce_20231211.pth",
    # CrowdPose-trained
    "rtmo-s_8xb32-700e_crowdpose": "rtmo-s_8xb32-700e_crowdpose-640x640-79f81c0d_20231211.pth",
    # NOTE(review): this file name starts with "rrtmo" (double "r"); it is kept
    # exactly as found - confirm against the upstream model zoo before changing.
    "rtmo-m_16xb16-700e_crowdpose": "rrtmo-m_16xb16-700e_crowdpose-640x640-0eaf670d_20231211.pth",
    "rtmo-l_16xb16-700e_crowdpose": "rtmo-l_16xb16-700e_crowdpose-640x640-1008211f_20231211.pth",
}

_HF_RETRAINABLE_FILES = {
    # Retrainable from HF repo
    "rtmo-s_coco_retrainable": "rtmo-s_coco_retrainable.pth",
    "rtmo-s_body6_retrainable": "body6_epoch_600.pth",
}

REMOTE_CHECKPOINTS = {
    **{name: _OPENMMLAB_RTMO_BASE + fname for name, fname in _OPENMMLAB_RTMO_FILES.items()},
    **{name: _HF_RETRAINABLE_BASE + fname for name, fname in _HF_RETRAINABLE_FILES.items()},
}
# Variants for inference (prefixes).
# Each row pairs an integer key with the config-name prefix to use;
# detect_rtmo_variant() looks the key up from the first dimension of the
# checkpoint's 'backbone.stem.conv.conv.weight' tensor.
_VARIANT_TABLE = (
    (24, "rtmo-t_8xb32-600e_body7-416x416"),
    (32, "rtmo-s_8xb32-600e_body7-640x640"),
    (48, "rtmo-m_16xb16-600e_body7-640x640"),
    (64, "rtmo-l_16xb16-600e_body7-640x640"),
)
VARIANT_PREFIX = dict(_VARIANT_TABLE)
# --- Helper: download checkpoint if remote ---
def get_checkpoint(path_or_key: str) -> str:
    """Return a local path for a checkpoint, downloading it if necessary.

    Args:
        path_or_key: Either a key of ``REMOTE_CHECKPOINTS`` or any other
            string, which is treated as an already-usable path.

    Returns:
        ``/tmp/<key>.pth`` when ``path_or_key`` is a known remote key (the
        file is downloaded on first use and reused afterwards); otherwise
        ``path_or_key`` unchanged.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    if path_or_key not in REMOTE_CHECKPOINTS:
        return path_or_key

    url = REMOTE_CHECKPOINTS[path_or_key]
    local_path = f"/tmp/{path_or_key}.pth"
    if os.path.exists(local_path):
        return local_path

    # Download to a side file and rename only after the transfer completed, so
    # an interrupted or failed download is never mistaken for a cached
    # checkpoint on the next call.
    partial_path = f"{local_path}.part"
    try:
        with requests.get(url, stream=True, timeout=60) as response:
            # Without this an HTTP error page would be saved as the checkpoint.
            response.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        os.replace(partial_path, local_path)
    finally:
        # After a successful rename the side file is gone and this is a no-op.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return local_path
# --- Detect variant alias from checkpoint ---
def detect_rtmo_variant(checkpoint_path: str) -> str:
    """Pick the config-name prefix matching a checkpoint file.

    The checkpoint is loaded on CPU and the first dimension of its
    'backbone.stem.conv.conv.weight' tensor is looked up in ``VARIANT_PREFIX``.

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load``.

    Returns:
        The matching ``VARIANT_PREFIX`` entry, or
        'rtmo-s_8xb32-600e_body7-640x640' when the size is not listed.

    Raises:
        KeyError: If the stem weight is missing from the checkpoint.
    """
    stem_weight_name = 'backbone.stem.conv.conv.weight'
    # NOTE(review): torch.load unpickles the file; the app also feeds
    # user-uploaded checkpoints through here - review whether a restricted
    # loading mode is acceptable for these files.
    loaded = torch.load(checkpoint_path, map_location='cpu')
    # The weights may sit under 'state_dict' or be the top-level mapping itself.
    weights = loaded.get('state_dict', loaded)
    if stem_weight_name in weights:
        stem_channels = weights[stem_weight_name].shape[0]
        return VARIANT_PREFIX.get(stem_channels, 'rtmo-s_8xb32-600e_body7-640x640')
    raise KeyError(f"Cannot find '{stem_weight_name}' in checkpoint.")
# --- Load inferencer ---
def load_inferencer(checkpoint_path=None, device=None):
    """Create an ``MMPoseInferencer`` for the given checkpoint.

    Args:
        checkpoint_path: Optional path to pose-model weights. When given, the
            config alias is derived from the file via ``detect_rtmo_variant``
            and the file is passed as ``pose2d_weights``. When falsy, the
            'rtmo' alias is used without explicit weights.
        device: Passed through unchanged as the inferencer's ``device``.

    Returns:
        The constructed ``MMPoseInferencer``.
    """
    if checkpoint_path:
        model_args = {
            'pose2d': detect_rtmo_variant(checkpoint_path),
            'pose2d_weights': checkpoint_path,
        }
    else:
        model_args = {'pose2d': 'rtmo'}
    return MMPoseInferencer(
        scope='mmpose',
        device=device,
        det_cat_ids=[0],
        **model_args,
    )
# --- Prediction function ---
def _resolve_input_path(image, video):
    """Return the path of the media file to run inference on.

    A provided video wins over the image. The image, if used, is written to
    /tmp/upload.jpg. Raises ``gr.Error`` when neither input is provided.
    """
    if video:
        # Gradio Video can come in as a filepath string, a dict with a 'name'
        # entry, or an object exposing a `.name` attribute.
        if isinstance(video, dict) and 'name' in video:
            return video['name']
        return getattr(video, "name", video)

    if image is None:
        # Previously this fell through to `image.save` and crashed with an
        # AttributeError on None.
        raise gr.Error("Please provide an image or a video.")

    inp_path = "/tmp/upload.jpg"
    # JPEG cannot store alpha or palette data; normalise so saving never fails.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image.save(inp_path)
    return inp_path


@spaces.GPU()
def predict(image: Image.Image,
            video,  # new video input
            remote_ckpt: str,
            upload_ckpt,
            bbox_thr: float,
            nms_thr: float):
    """Run pose inference on an image or a video and return the visualisation.

    Args:
        image: PIL image to process; used only when ``video`` is falsy.
        video: Video input (path string, dict with 'name', or object with
            ``.name``); takes precedence over ``image``.
        remote_ckpt: Checkpoint name resolved through ``get_checkpoint``; used
            only when ``upload_ckpt`` is falsy.
        upload_ckpt: Optional uploaded checkpoint (path string or object with
            ``.name``); takes precedence over ``remote_ckpt``.
        bbox_thr: Forwarded to the inferencer as ``bbox_thr``.
        nms_thr: Forwarded to the inferencer as ``nms_thr``.

    Returns:
        A 3-tuple ``(annotated_image, annotated_video_path, active)``. For a
        video input the image slot is ``None``; for an image input the video
        slot is ``None``. ``active`` is the uploaded file's base name or
        ``remote_ckpt``.

    Raises:
        gr.Error: If neither an image nor a video is provided.
    """
    # 1) Write image or pick up video file
    inp_path = _resolve_input_path(image, video)

    # 2) Determine by extension if this is video
    ext = os.path.splitext(inp_path)[1].lower()
    is_video = ext in (".mp4", ".mov", ".avi", ".mkv", ".webm")

    # checkpoint selection
    if upload_ckpt:
        # Accept both file-like objects and plain path strings, mirroring how
        # the video input is handled.
        ckpt_path = getattr(upload_ckpt, "name", upload_ckpt)
        active = os.path.basename(ckpt_path)
    else:
        ckpt_path = get_checkpoint(remote_ckpt)
        active = remote_ckpt

    # prepare (and clear) output dir
    vis_dir = "/tmp/vis"
    if os.path.exists(vis_dir):
        shutil.rmtree(vis_dir)
    os.makedirs(vis_dir, exist_ok=True)

    # run inferencer (handles both image & video); the call returns an
    # iterable that has to be consumed for the visualisations to be written.
    inferencer = load_inferencer(checkpoint_path=ckpt_path, device=None)
    for _ in inferencer(
        inputs=inp_path,
        bbox_thr=bbox_thr,
        nms_thr=nms_thr,
        pose_based_nms=True,
        show=False,
        vis_out_dir=vis_dir,
    ):
        pass

    # collect and return results
    out_files = sorted(os.listdir(vis_dir))
    if is_video:
        # return only the annotated video path
        out_vid = next(
            (f for f in out_files if f.lower().endswith((".mp4", ".mov", ".avi"))),
            None,
        )
        out_vid_path = os.path.join(vis_dir, out_vid) if out_vid else None
        return None, out_vid_path, active

    # return only the annotated image
    img_f = out_files[0] if out_files else None
    if img_f and not img_f.lower().endswith((".mp4", ".mov", ".avi")):
        vis_img = Image.open(os.path.join(vis_dir, img_f))
    else:
        vis_img = None
    return vis_img, None, active
# --- Gradio UI ---
def main():
    """Build the Gradio Blocks UI for the RTMO pose demo and launch it."""
    checkpoint_names = list(REMOTE_CHECKPOINTS.keys())

    # Each row follows the order of the input components:
    # image, video, remote checkpoint, uploaded checkpoint, bbox thr, nms thr.
    example_rows = [
        ["https://images.pexels.com/photos/1858175/pexels-photo-1858175.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-s_coco_retrainable", None, 0.1, 0.65],
        ["https://images.pexels.com/photos/3779706/pexels-photo-3779706.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-t_8xb32-600e_body7", None, 0.1, 0.65],
        ["https://images.pexels.com/photos/220453/pexels-photo-220453.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-s_8xb32-600e_coco", None, 0.1, 0.65],
        # 4th example: video input (the URL points at "Fred Ott Sneeze 1894"
        # on archive.org)
        [
            None,
            "https://archive.org/download/fred-otts-sneeze/Fred%20Ott%20Sneeze%201894%20GG%20Restore.mp4",
            "rtmo-s_coco_retrainable",
            None,
            0.1,
            0.65,
        ],
    ]

    with gr.Blocks() as demo:
        gr.Markdown("## RTMO Pose Demo")
        with gr.Row():
            # Left column: inputs and settings.
            with gr.Column(scale=1, min_width=300):
                img_input = gr.Image(type="pil", label="Upload Image")
                video_input = gr.Video(label="Upload Video")
                remote_dd = gr.Dropdown(
                    label="Select Remote Checkpoint",
                    choices=checkpoint_names,
                    value=checkpoint_names[0],
                )
                upload_ckpt = gr.File(
                    file_types=['.pth'],
                    label="Or Upload Your Own Checkpoint (optional)",
                )
                bbox_thr = gr.Slider(0.0, 1.0, value=0.1, step=0.01, label="Bounding Box Threshold")
                nms_thr = gr.Slider(0.0, 1.0, value=0.65, step=0.01, label="NMS Threshold")
                run_btn = gr.Button("Run Inference")
            # Right column: results.
            with gr.Column(scale=2):
                output_img = gr.Image(
                    type="pil",
                    label="Annotated Image",
                    elem_id="output_image",
                    interactive=False,
                )
                output_video = gr.Video(label="Annotated Video", interactive=False)
                active_tb = gr.Textbox(label="Active Checkpoint", interactive=False)

        # The examples widget and the button drive predict() with the same
        # components, in predict()'s parameter / return order.
        input_components = [img_input, video_input, remote_dd, upload_ckpt, bbox_thr, nms_thr]
        output_components = [output_img, output_video, active_tb]

        # Examples for quick testing
        gr.Examples(
            examples=example_rows,
            inputs=input_components,
            outputs=output_components,
            fn=predict,
            cache_examples=False,
            label="Examples",
            examples_per_page=4,
        )
        run_btn.click(
            predict,
            inputs=input_components,
            outputs=output_components,
        )

    demo.launch()
# Build and launch the UI only when this file is run as a script.
if __name__ == "__main__":
    main()