Spaces:
Running
on
Zero
Running
on
Zero
add 3 examples
Browse files
app.py
CHANGED
@@ -57,7 +57,6 @@ VARIANT_PREFIX = {
|
|
57 |
|
58 |
# βββ Helper: download checkpoint if remote βββ
|
59 |
def get_checkpoint(path_or_key: str) -> str:
|
60 |
-
# If it's a key in REMOTE_CHECKPOINTS, download
|
61 |
if path_or_key in REMOTE_CHECKPOINTS:
|
62 |
url = REMOTE_CHECKPOINTS[path_or_key]
|
63 |
local_path = f"/tmp/{path_or_key}.pth"
|
@@ -67,7 +66,6 @@ def get_checkpoint(path_or_key: str) -> str:
|
|
67 |
for chunk in r.iter_content(1024):
|
68 |
f.write(chunk)
|
69 |
return local_path
|
70 |
-
# Otherwise assume it's a local file path
|
71 |
return path_or_key
|
72 |
|
73 |
# βββ Detect variant alias from checkpoint βββ
|
@@ -88,34 +86,26 @@ def load_inferencer(checkpoint_path=None, device=None):
|
|
88 |
kwargs['pose2d'] = variant
|
89 |
kwargs['pose2d_weights'] = checkpoint_path
|
90 |
else:
|
91 |
-
# default to rtmo-s
|
92 |
kwargs['pose2d'] = 'rtmo'
|
93 |
return MMPoseInferencer(**kwargs)
|
94 |
|
95 |
-
#
|
96 |
@spaces.GPU()
|
97 |
def predict(image: Image.Image,
|
98 |
remote_ckpt: str,
|
99 |
upload_ckpt,
|
100 |
bbox_thr: float,
|
101 |
nms_thr: float):
|
102 |
-
# save input image
|
103 |
inp_path = "/tmp/upload.jpg"
|
104 |
image.save(inp_path)
|
105 |
-
|
106 |
-
# choose checkpoint: upload overrides remote
|
107 |
if upload_ckpt:
|
108 |
ckpt_path = upload_ckpt.name
|
109 |
active = os.path.basename(ckpt_path)
|
110 |
else:
|
111 |
ckpt_path = get_checkpoint(remote_ckpt)
|
112 |
active = remote_ckpt
|
113 |
-
|
114 |
-
# prepare output dir
|
115 |
vis_dir = "/tmp/vis"
|
116 |
os.makedirs(vis_dir, exist_ok=True)
|
117 |
-
|
118 |
-
# run inference
|
119 |
inferencer = load_inferencer(checkpoint_path=ckpt_path, device=None)
|
120 |
for result in inferencer(
|
121 |
inputs=inp_path,
|
@@ -126,14 +116,11 @@ def predict(image: Image.Image,
|
|
126 |
vis_out_dir=vis_dir,
|
127 |
):
|
128 |
pass
|
129 |
-
|
130 |
-
# load result image
|
131 |
out_files = sorted(os.listdir(vis_dir))
|
132 |
vis_img = Image.open(os.path.join(vis_dir, out_files[0])) if out_files else None
|
133 |
return vis_img, active
|
134 |
|
135 |
-
#
|
136 |
-
|
137 |
def main():
|
138 |
with gr.Blocks() as demo:
|
139 |
gr.Markdown("## RTMO Pose Demo")
|
@@ -153,6 +140,25 @@ def main():
|
|
153 |
output_img = gr.Image(type="pil", label="Annotated Image",
|
154 |
elem_id="output_image", interactive=False)
|
155 |
active_tb = gr.Textbox(label="Active Checkpoint", interactive=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
run_btn.click(predict,
|
157 |
inputs=[img_input, remote_dd, upload_ckpt, bbox_thr, nms_thr],
|
158 |
outputs=[output_img, active_tb])
|
|
|
57 |
|
58 |
# βββ Helper: download checkpoint if remote βββ
|
59 |
def get_checkpoint(path_or_key: str) -> str:
|
|
|
60 |
if path_or_key in REMOTE_CHECKPOINTS:
|
61 |
url = REMOTE_CHECKPOINTS[path_or_key]
|
62 |
local_path = f"/tmp/{path_or_key}.pth"
|
|
|
66 |
for chunk in r.iter_content(1024):
|
67 |
f.write(chunk)
|
68 |
return local_path
|
|
|
69 |
return path_or_key
|
70 |
|
71 |
# βββ Detect variant alias from checkpoint βββ
|
|
|
86 |
kwargs['pose2d'] = variant
|
87 |
kwargs['pose2d_weights'] = checkpoint_path
|
88 |
else:
|
|
|
89 |
kwargs['pose2d'] = 'rtmo'
|
90 |
return MMPoseInferencer(**kwargs)
|
91 |
|
92 |
+
# ββββ Prediction function ββββ
|
93 |
@spaces.GPU()
|
94 |
def predict(image: Image.Image,
|
95 |
remote_ckpt: str,
|
96 |
upload_ckpt,
|
97 |
bbox_thr: float,
|
98 |
nms_thr: float):
|
|
|
99 |
inp_path = "/tmp/upload.jpg"
|
100 |
image.save(inp_path)
|
|
|
|
|
101 |
if upload_ckpt:
|
102 |
ckpt_path = upload_ckpt.name
|
103 |
active = os.path.basename(ckpt_path)
|
104 |
else:
|
105 |
ckpt_path = get_checkpoint(remote_ckpt)
|
106 |
active = remote_ckpt
|
|
|
|
|
107 |
vis_dir = "/tmp/vis"
|
108 |
os.makedirs(vis_dir, exist_ok=True)
|
|
|
|
|
109 |
inferencer = load_inferencer(checkpoint_path=ckpt_path, device=None)
|
110 |
for result in inferencer(
|
111 |
inputs=inp_path,
|
|
|
116 |
vis_out_dir=vis_dir,
|
117 |
):
|
118 |
pass
|
|
|
|
|
119 |
out_files = sorted(os.listdir(vis_dir))
|
120 |
vis_img = Image.open(os.path.join(vis_dir, out_files[0])) if out_files else None
|
121 |
return vis_img, active
|
122 |
|
123 |
+
# ββββ Gradio UI ββββ
|
|
|
124 |
def main():
|
125 |
with gr.Blocks() as demo:
|
126 |
gr.Markdown("## RTMO Pose Demo")
|
|
|
140 |
output_img = gr.Image(type="pil", label="Annotated Image",
|
141 |
elem_id="output_image", interactive=False)
|
142 |
active_tb = gr.Textbox(label="Active Checkpoint", interactive=False)
|
143 |
+
|
144 |
+
# Examples for quick testing
|
145 |
+
gr.Examples(
|
146 |
+
examples=[
|
147 |
+
["https://images.pexels.com/photos/1858175/pexels-photo-1858175.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=614",
|
148 |
+
"rtmo-s_coco_retrainable", None, 0.1, 0.65],
|
149 |
+
["https://images.pexels.com/photos/3779706/pexels-photo-3779706.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=614",
|
150 |
+
"rtmo-t_8xb32-600e_body7", None, 0.1, 0.65],
|
151 |
+
["https://images.pexels.com/photos/220453/pexels-photo-220453.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=614",
|
152 |
+
"rtmo-s_8xb32-600e_coco", None, 0.1, 0.65],
|
153 |
+
],
|
154 |
+
inputs=[img_input, remote_dd, upload_ckpt, bbox_thr, nms_thr],
|
155 |
+
outputs=[output_img, active_tb],
|
156 |
+
fn=predict,
|
157 |
+
cache_examples=False,
|
158 |
+
label="Examples",
|
159 |
+
examples_per_page=3
|
160 |
+
)
|
161 |
+
|
162 |
run_btn.click(predict,
|
163 |
inputs=[img_input, remote_dd, upload_ckpt, bbox_thr, nms_thr],
|
164 |
outputs=[output_img, active_tb])
|