PengWeixuanSZU committed on
Commit
ca68585
·
verified ·
1 Parent(s): 6dd3fdf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -244,9 +244,8 @@ def track_video(n_frames,video_state):
244
  video_state["origin_images"] = images
245
  images = np.array(images)
246
 
247
- global video_predictor
248
- video_predictor=video_predictor.to("cuda")
249
- inference_state = video_predictor.init_state(images=images/255, device="cuda")
250
  video_state["inference_state"] = inference_state
251
 
252
  if len(torch.from_numpy(video_state["masks"][0]).shape) == 3:
@@ -254,7 +253,7 @@ def track_video(n_frames,video_state):
254
  else:
255
  mask = torch.from_numpy(video_state["masks"][0])
256
 
257
- video_predictor.add_new_mask(
258
  inference_state=inference_state,
259
  frame_idx=0,
260
  obj_id=obj_id,
@@ -265,7 +264,7 @@ def track_video(n_frames,video_state):
265
  mask_frames = []
266
  color = np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32) / 255.0
267
  color = color[None, None, :]
268
- for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
269
  frame = images[out_frame_idx].astype(np.float32) / 255.0
270
  mask = np.zeros((H, W, 3), dtype=np.float32)
271
  for i, logit in enumerate(out_mask_logits):
 
244
  video_state["origin_images"] = images
245
  images = np.array(images)
246
 
247
+ video_predictor_local=video_predictor.to("cuda")
248
+ inference_state = video_predictor_local.init_state(images=images/255, device="cuda")
 
249
  video_state["inference_state"] = inference_state
250
 
251
  if len(torch.from_numpy(video_state["masks"][0]).shape) == 3:
 
253
  else:
254
  mask = torch.from_numpy(video_state["masks"][0])
255
 
256
+ video_predictor_local.add_new_mask(
257
  inference_state=inference_state,
258
  frame_idx=0,
259
  obj_id=obj_id,
 
264
  mask_frames = []
265
  color = np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32) / 255.0
266
  color = color[None, None, :]
267
+ for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor_local.propagate_in_video(inference_state):
268
  frame = images[out_frame_idx].astype(np.float32) / 255.0
269
  mask = np.zeros((H, W, 3), dtype=np.float32)
270
  for i, logit in enumerate(out_mask_logits):