Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -244,9 +244,8 @@ def track_video(n_frames,video_state):
|
|
244 |
video_state["origin_images"] = images
|
245 |
images = np.array(images)
|
246 |
|
247 |
-
|
248 |
-
|
249 |
-
inference_state = video_predictor.init_state(images=images/255, device="cuda")
|
250 |
video_state["inference_state"] = inference_state
|
251 |
|
252 |
if len(torch.from_numpy(video_state["masks"][0]).shape) == 3:
|
@@ -254,7 +253,7 @@ def track_video(n_frames,video_state):
|
|
254 |
else:
|
255 |
mask = torch.from_numpy(video_state["masks"][0])
|
256 |
|
257 |
-
|
258 |
inference_state=inference_state,
|
259 |
frame_idx=0,
|
260 |
obj_id=obj_id,
|
@@ -265,7 +264,7 @@ def track_video(n_frames,video_state):
|
|
265 |
mask_frames = []
|
266 |
color = np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32) / 255.0
|
267 |
color = color[None, None, :]
|
268 |
-
for out_frame_idx, out_obj_ids, out_mask_logits in
|
269 |
frame = images[out_frame_idx].astype(np.float32) / 255.0
|
270 |
mask = np.zeros((H, W, 3), dtype=np.float32)
|
271 |
for i, logit in enumerate(out_mask_logits):
|
|
|
244 |
video_state["origin_images"] = images
|
245 |
images = np.array(images)
|
246 |
|
247 |
+
video_predictor_local=video_predictor.to("cuda")
|
248 |
+
inference_state = video_predictor_local.init_state(images=images/255, device="cuda")
|
|
|
249 |
video_state["inference_state"] = inference_state
|
250 |
|
251 |
if len(torch.from_numpy(video_state["masks"][0]).shape) == 3:
|
|
|
253 |
else:
|
254 |
mask = torch.from_numpy(video_state["masks"][0])
|
255 |
|
256 |
+
video_predictor_local.add_new_mask(
|
257 |
inference_state=inference_state,
|
258 |
frame_idx=0,
|
259 |
obj_id=obj_id,
|
|
|
264 |
mask_frames = []
|
265 |
color = np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32) / 255.0
|
266 |
color = color[None, None, :]
|
267 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor_local.propagate_in_video(inference_state):
|
268 |
frame = images[out_frame_idx].astype(np.float32) / 255.0
|
269 |
mask = np.zeros((H, W, 3), dtype=np.float32)
|
270 |
for i, logit in enumerate(out_mask_logits):
|