move inference_states out of gr.State
app.py (CHANGED)
@@ -73,6 +73,7 @@ OBJ_ID = 0
 sam2_checkpoint = "checkpoints/edgetam.pt"
 model_cfg = "edgetam.yaml"
 predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
+global_inference_states = {}


 def get_video_fps(video_path):
@@ -89,15 +90,17 @@ def get_video_fps(video_path):
     return fps


-def reset(session_state):
+def reset():
     predictor.to("cpu")
     session_state["input_points"] = []
     session_state["input_labels"] = []
-    if session_state["inference_state"] is not None:
-        predictor.reset_state(session_state["inference_state"])
+
+    session_id = id(session_state)
+    if global_inference_states[session_id] is not None:
+        predictor.reset_state(global_inference_states[session_id])
     session_state["first_frame"] = None
     session_state["all_frames"] = None
-    session_state["inference_state"] = None
+    global_inference_states[session_id] = None
     return (
         None,
         gr.update(open=True),
@@ -112,8 +115,9 @@ def clear_points(session_state):
     predictor.to("cpu")
     session_state["input_points"] = []
     session_state["input_labels"] = []
-    if session_state["inference_state"]["tracking_has_started"]:
-        predictor.reset_state(session_state["inference_state"])
+    session_id = id(session_state)
+    if global_inference_states[session_id]["tracking_has_started"]:
+        predictor.reset_state(global_inference_states[session_id])
     return (
         session_state["first_frame"],
         None,
@@ -168,7 +172,9 @@ def preprocess_video_in(video_path, session_state):
     session_state["first_frame"] = copy.deepcopy(first_frame)
     session_state["all_frames"] = all_frames

-    session_state["inference_state"] = predictor.init_state(video_path=video_path)
+    session_id = id(session_state)
+    global_inference_states[session_id] = predictor.init_state(video_path=video_path)
+
     session_state["input_points"] = []
     session_state["input_labels"] = []

@@ -230,8 +236,9 @@ def segment_with_points(
     points = np.array(session_state["input_points"], dtype=np.float32)
     # for labels, `1` means positive click and `0` means negative click
     labels = np.array(session_state["input_labels"], np.int32)
+    session_id = id(session_state)
     _, _, out_mask_logits = predictor.add_new_points(
-        inference_state=session_state["inference_state"],
+        inference_state=global_inference_states[session_id],
         frame_idx=0,
         obj_id=OBJ_ID,
         points=points,
@@ -270,10 +277,11 @@ def propagate_to_all(
     torch.backends.cuda.matmul.allow_tf32 = True
     torch.backends.cudnn.allow_tf32 = True
     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        session_id = id(session_state)
         if (
             len(session_state["input_points"]) == 0
             or video_in is None
-            or session_state["inference_state"] is None
+            or global_inference_states[session_id] is None
         ):
             return (
                 None,
@@ -286,7 +294,7 @@
         )  # video_segments contains the per-frame segmentation results
         print("starting propagate_in_video")
         for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
-            session_state["inference_state"]
+            global_inference_states[session_id]
         ):
             video_segments[out_frame_idx] = {
                 out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
@@ -340,7 +348,6 @@ with gr.Blocks() as demo:
             "all_frames": None,
             "input_points": [],
             "input_labels": [],
-            "inference_state": None,
         }
     )

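The gist of the change, as a minimal sketch: gr.State keeps only the small, copyable per-session values (click points, labels, frames), while the predictor's inference state moves into a module-level dict keyed by id(session_state), so each browser session still maps to its own entry. The checkpoint path, config name, global_inference_states, and the predictor calls below are taken from the diff; the handler names (start_session, end_session), the components, and the event wiring are illustrative assumptions rather than the Space's actual code, and the sketch keeps session_state as an explicit handler argument.

import gradio as gr
from sam2.build_sam import build_sam2_video_predictor  # import path assumed from the SAM2/EdgeTAM package

sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")

# Module-level store for per-session inference states, keyed by the identity
# of each session's gr.State dict (the pattern used in the commit above).
global_inference_states = {}


def start_session(video_path, session_state):
    # Gradio hands each browser session its own copy of the gr.State dict,
    # so id() of that dict works as a per-session key.
    session_id = id(session_state)
    global_inference_states[session_id] = predictor.init_state(video_path=video_path)
    session_state["input_points"] = []
    session_state["input_labels"] = []
    return session_state


def end_session(session_state):
    # Drop the heavy inference state but keep the session dict itself around.
    session_id = id(session_state)
    if global_inference_states.get(session_id) is not None:
        predictor.reset_state(global_inference_states[session_id])
    global_inference_states[session_id] = None
    return session_state


with gr.Blocks() as demo:
    # Only lightweight, copyable values live in gr.State; the inference
    # state itself stays in global_inference_states above.
    session_state = gr.State(
        {
            "first_frame": None,
            "all_frames": None,
            "input_points": [],
            "input_labels": [],
        }
    )
    video_in = gr.Video(label="Input video")
    reset_btn = gr.Button("Reset")
    video_in.upload(start_session, [video_in, session_state], [session_state])
    reset_btn.click(end_session, [session_state], [session_state])

The keying relies on the handlers mutating and returning the same dict object, so its id() stays stable for the lifetime of the session; if a handler ever returned a fresh dict, the corresponding entry in global_inference_states would be orphaned.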