bla committed · Commit 2a466e4 · verified · 1 parent: 99d098f

Update app.py

Files changed (1):
  1. app.py +43 -25
app.py CHANGED
@@ -76,8 +76,9 @@ OBJ_ID = 0
 sam2_checkpoint = "checkpoints/edgetam.pt"
 model_cfg = "edgetam.yaml"
 # Ensure predictor is explicitly built for CPU
+# The device is set here and with .to("cpu")
 predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
-predictor.to("cpu") # Explicitly move to CPU, though device="cpu" should handle it
+predictor.to("cpu") # Explicitly move to CPU after building
 print("predictor loaded on CPU")

 # Removed autocast block for maximum CPU compatibility
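The hunk above pins the EdgeTAM/SAM2 predictor to the CPU twice, via device="cpu" at build time and an explicit .to("cpu"). As a minimal sanity-check sketch (not part of the commit; the helper name is ours and it assumes the predictor is an ordinary torch.nn.Module):

import torch

def assert_predictor_on_cpu(module: torch.nn.Module) -> None:
    # Hypothetical helper: walk parameters and buffers and fail loudly if any tensor
    # ended up somewhere other than the CPU after build_sam2_video_predictor(..., device="cpu").
    for name, tensor in list(module.named_parameters()) + list(module.named_buffers()):
        if tensor.device.type != "cpu":
            raise RuntimeError(f"{name} is on {tensor.device}, expected cpu")

# assert_predictor_on_cpu(predictor)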
@@ -121,6 +122,7 @@ def preprocess_video_in(video_path, session_state):
             "input_points": [],
             "input_labels": [],
             "inference_state": None,
+            "video_path": None,
         }
     )

@@ -143,6 +145,7 @@ def preprocess_video_in(video_path, session_state):
             "input_points": [],
             "input_labels": [],
             "inference_state": None,
+            "video_path": None,
         }
     )

@@ -178,16 +181,18 @@ def preprocess_video_in(video_path, session_state):
             "input_points": [],
             "input_labels": [],
             "inference_state": None,
+            "video_path": None,
         }
     )

-
+    # Update session state with frames and path
     session_state["first_frame"] = copy.deepcopy(first_frame) # Store a copy
     session_state["all_frames"] = all_frames
+    session_state["video_path"] = video_path # Store the path
     session_state["input_points"] = []
     session_state["input_labels"] = []
-    # Initialize state explicitly for CPU
-    session_state["inference_state"] = predictor.init_state(video_path=video_path, device="cpu")
+    # Initialize state *without* the device argument
+    session_state["inference_state"] = predictor.init_state(video_path=video_path)
     print("Video loaded and predictor state initialized.")

     return [
@@ -213,9 +218,10 @@ def reset(session_state):
         predictor.reset_state(session_state["inference_state"])
         # After reset, we also discard the state object as a new video might be loaded
         session_state["inference_state"] = None
-    # Clear frames
+    # Clear frames and video path
     session_state["first_frame"] = None
     session_state["all_frames"] = None
+    session_state["video_path"] = None

     # Update UI elements to their initial state
     return (
@@ -238,18 +244,19 @@ def clear_points(session_state):
     session_state["input_points"] = []
     session_state["input_labels"] = []

-    # If inference state exists, reset it. This clears internal masks/features
+    # Reset the predictor state if it exists. This clears internal masks/features
     # but keeps the video context initialized by preprocess_video_in.
     if session_state["inference_state"] is not None:
         predictor.reset_state(session_state["inference_state"])
-        # After resetting the state, we need to re-initialize it to be ready for new points.
-        # Pass the original video path stored in the state.
-        if "video_path" in session_state["inference_state"] and session_state["inference_state"]["video_path"] is not None:
-            session_state["inference_state"] = predictor.init_state(video_path=session_state["inference_state"]["video_path"])
+        # After resetting the state, if we still have the video path, re-initialize the state
+        # to be ready for new points on the same video.
+        if session_state["video_path"] is not None:
+            # Re-initialize state *without* the device argument
+            session_state["inference_state"] = predictor.init_state(video_path=session_state["video_path"])
+            print("Predictor state re-initialized after clearing points.")
         else:
-            # This case should ideally not happen if preprocess_video_in ran correctly
             print("Warning: Could not re-initialize state after clear_points (video_path missing).")
-            session_state["inference_state"] = None
+            session_state["inference_state"] = None # Ensure state is None if video_path is gone


     # Re-render the points_map with no points drawn (just the first frame)
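For reference, a sketch of the per-session dictionary these handlers read and write. Only the keys appear in the diff; the field descriptions are our assumptions inferred from how the hunks use them:

session_state_example = {
    "first_frame": None,      # deep copy of the first RGB frame, used as the click canvas
    "all_frames": None,       # list of RGB frames extracted from the uploaded video
    "input_points": [],       # clicked (x, y) coordinates
    "input_labels": [],       # SAM-style label per click (assumed: 1 = include, 0 = exclude)
    "inference_state": None,  # object returned by predictor.init_state(video_path=...)
    "video_path": None,       # original upload path, kept so the state can be re-initialized
}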
@@ -324,6 +331,7 @@ def segment_with_points(
     points = np.array(session_state["input_points"], dtype=np.float32)
     labels = np.array(session_state["input_labels"], np.int32)

+    # Ensure tensors are on CPU
     points_tensor = torch.tensor(points, dtype=torch.float32, device="cpu").unsqueeze(0) # Add batch dim
     labels_tensor = torch.tensor(labels, dtype=torch.int32, device="cpu").unsqueeze(0) # Add batch dim

@@ -340,6 +348,7 @@ def segment_with_points(

     # Process logits: detach from graph, move to CPU, apply threshold
     # out_mask_logits is [batch_size, H, W] (batch_size=1 here)
+    # out_mask_logits[0] is the tensor for obj_id=OBJ_ID
     mask_tensor = (out_mask_logits[0][0].detach().cpu() > 0.0) # Apply threshold and get the single mask tensor [H, W]
     mask_numpy = mask_tensor.numpy() # Convert to numpy

@@ -363,7 +372,8 @@ def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
     # Ensure mask is a numpy array (and boolean)
     if isinstance(mask, torch.Tensor):
         mask = mask.detach().cpu().numpy() # Ensure it's on CPU and converted to numpy
-        mask = mask.astype(bool) # Ensure mask is boolean
+    # Convert potential float/int mask to boolean mask
+    mask = mask.astype(bool)

     if random_color:
         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) # RGBA with 0.6 alpha
@@ -374,15 +384,15 @@ def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):

     # Ensure mask has H, W dimensions
     if mask.ndim == 3:
-        mask = mask.squeeze() # Remove singular dimensions
+        mask = mask.squeeze() # Remove singular dimensions like (H, W, 1)
     if mask.ndim != 2:
         print(f"Warning: show_mask received mask with shape {mask.shape}. Expected 2D.")
         # Create an empty transparent image if mask shape is unexpected
+        h, w = mask.shape[:2] if mask.ndim >= 2 else (100, 100) # Use actual shape if possible, otherwise default
         if convert_to_image:
-            return Image.fromarray(np.zeros((*mask.shape[:2], 4), dtype=np.uint8), "RGBA")
+            return Image.fromarray(np.zeros((h, w, 4), dtype=np.uint8), "RGBA")
         else:
-            return np.zeros((*mask.shape[:2], 4), dtype=np.uint8)
-

     h, w = mask.shape
     # Create an RGBA image from the mask and color
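A tiny illustration (ours, not from the commit) of why the squeeze step above matters before the later h, w = mask.shape unpacking:

import numpy as np

m = np.zeros((256, 256, 1), dtype=bool)  # mask with a trailing singleton channel
m = m.squeeze()                          # -> shape (256, 256), safe for h, w = m.shape
assert m.ndim == 2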
@@ -403,7 +413,9 @@ def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):

 # Removed @spaces.GPU decorator
 def propagate_to_all(
-    video_in, # Keep video_in path to potentially get FPS again if needed
+    # We don't strictly need video_in path here anymore as it's in session_state,
+    # but keeping it is fine. Accessing session_state["video_path"] is more robust.
+    video_in,
     session_state,
 ):
     """Runs mask propagation through the video and generates the output video."""
@@ -413,6 +425,7 @@ def propagate_to_all(
         len(session_state["input_points"]) == 0 # Need at least one point
         or session_state["all_frames"] is None
         or session_state["inference_state"] is None
+        or session_state["video_path"] is None # Ensure we have the original video path
     ):
         print("Error: Cannot propagate. No points selected, video not loaded, or inference state missing.")
         return (
@@ -424,13 +437,16 @@ def propagate_to_all(
     # The generator yields (frame_idx, obj_ids, mask_logits)
     video_segments = {}
     try:
+        # This loop performs the core tracking prediction frame by frame
         for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
             session_state["inference_state"]
         ):
             # Process logits: detach from graph, move to CPU, convert to numpy boolean mask
             # Ensure tensor is on CPU before converting to numpy
             video_segments[out_frame_idx] = {
-                out_obj_id: (out_mask_logits[i].detach().cpu() > 0.0).numpy()
+                # out_mask_logits is a list of tensors (one per object tracked in this frame)
+                # Each tensor is [batch_size, H, W]. Batch size is 1 here.
+                out_obj_id: (out_mask_logits[i][0].detach().cpu() > 0.0).numpy()
                 for i, out_obj_id in enumerate(out_obj_ids)
             }
             # Optional: print progress
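The loop above builds a nested mapping keyed by frame index and object id. A usage sketch of how it is assumed to be consumed downstream (video_segments and OBJ_ID come from the surrounding code; shapes are our assumption):

# video_segments[frame_idx][obj_id] -> boolean numpy array of shape (H, W)
for frame_idx, masks_by_obj in sorted(video_segments.items()):
    mask = masks_by_obj[OBJ_ID]  # boolean numpy mask for the single tracked object
    print(frame_idx, mask.shape, mask.dtype)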
@@ -447,7 +463,8 @@ def propagate_to_all(

     output_frames = []
     # Iterate through all original frames to generate output video
-    for out_frame_idx in range(len(session_state["all_frames"])):
+    total_frames = len(session_state["all_frames"])
+    for out_frame_idx in range(total_frames):
         original_frame_rgb = session_state["all_frames"][out_frame_idx]
         # Convert original frame to RGBA for compositing
         transparent_background = Image.fromarray(original_frame_rgb).convert("RGBA")
@@ -471,16 +488,17 @@ def propagate_to_all(


     # Define output path in a temporary directory
+    # Use os.path.join for cross-platform compatibility
     unique_id = datetime.now().strftime("%Y%m%d%H%M%S%f") # Use microseconds for more uniqueness
     final_vid_filename = f"output_video_{unique_id}.mp4"
-    # Use os.path.join for cross-platform compatibility
     final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_filename)
     print(f"Output video path: {final_vid_output_path}")


     # Create a video clip from the image sequence
     # Get original FPS or default
-    original_fps = get_video_fps(video_in) # Re-get FPS from the input file path
+    # Get FPS from the stored video path in session state
+    original_fps = get_video_fps(session_state["video_path"])
     fps = original_fps if original_fps is not None and original_fps > 0 else 30 # Default to 30 if detection fails or is zero
     print(f"Creating output video with FPS: {fps}")

@@ -512,7 +530,7 @@ def propagate_to_all(
         final_vid_output_path,
         codec="libx264",
         fps=fps, # Ensure correct FPS is used during writing
-        preset="medium", # CPU optimization: 'fast', 'faster', 'veryfast' are options for speed
+        preset="medium", # CPU optimization: 'fast', 'faster', 'veryfast' are options for speed vs size
         threads="auto", # CPU optimization: Use multiple cores
         logger=None # Suppress moviepy output
     )
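The FPS fed into write_videofile above comes from get_video_fps, which is defined elsewhere in app.py and not part of this diff. An OpenCV-based stand-in might look like this sketch (the function name is ours and it assumes cv2 is installed; the app's actual helper may differ):

import cv2

def get_video_fps_sketch(video_path):
    # Hypothetical stand-in for the app's get_video_fps helper.
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    return fps if fps and fps > 0 else None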
@@ -714,8 +732,8 @@ with gr.Blocks() as demo:
     ).then( # Then, run the propagation function
         fn=propagate_to_all,
         inputs=[
-            video_in, # Get the input video path
-            session_state, # Pass session state (contains frames, points, inference_state)
+            video_in, # Get the input video path (can also get from session_state["video_path"])
+            session_state, # Pass session state (contains frames, points, inference_state, video_path)
         ],
         outputs=[
             output_video, # Update output video player with result
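For context on the hunk above, Gradio's .then() chains a second callback after the first event handler finishes. A minimal sketch of that wiring (component and button names are ours; it assumes propagate_to_all from this file is in scope and the real app chains it after a UI-update step):

import gradio as gr

with gr.Blocks() as sketch_demo:
    video_in = gr.Video()
    output_video = gr.Video()
    session_state = gr.State({})
    run_btn = gr.Button("Track")

    # First callback is a placeholder; the real app updates UI state here,
    # then propagate_to_all runs with the input video and the session state.
    run_btn.click(fn=lambda: None, inputs=None, outputs=None).then(
        fn=propagate_to_all,
        inputs=[video_in, session_state],
        outputs=[output_video],
    )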
 