Update app.py
app.py (CHANGED)
@@ -76,8 +76,9 @@ OBJ_ID = 0
 sam2_checkpoint = "checkpoints/edgetam.pt"
 model_cfg = "edgetam.yaml"
 # Ensure predictor is explicitly built for CPU
+# The device is set here and with .to("cpu")
 predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
-predictor.to("cpu") # Explicitly move to CPU
+predictor.to("cpu") # Explicitly move to CPU after building
 print("predictor loaded on CPU")

 # Removed autocast block for maximum CPU compatibility
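The hunk above keeps the whole predictor on the CPU. As a rough standalone illustration of that pattern (a stand-in torch module, not the real build_sam2_video_predictor):

    import torch

    device = torch.device("cpu")                    # never touch CUDA
    model = torch.nn.Conv2d(3, 8, kernel_size=3)    # stand-in for the SAM2 video predictor
    model = model.to(device)                        # same effect as predictor.to("cpu")
    print(next(model.parameters()).device)          # -> cpu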
@@ -121,6 +122,7 @@ def preprocess_video_in(video_path, session_state):
             "input_points": [],
             "input_labels": [],
             "inference_state": None,
+            "video_path": None,
         }
     )

@@ -143,6 +145,7 @@ def preprocess_video_in(video_path, session_state):
             "input_points": [],
             "input_labels": [],
             "inference_state": None,
+            "video_path": None,
         }
     )

@@ -178,16 +181,18 @@ def preprocess_video_in(video_path, session_state):
             "input_points": [],
             "input_labels": [],
             "inference_state": None,
+            "video_path": None,
         }
     )

-
+    # Update session state with frames and path
     session_state["first_frame"] = copy.deepcopy(first_frame) # Store a copy
     session_state["all_frames"] = all_frames
+    session_state["video_path"] = video_path # Store the path
     session_state["input_points"] = []
     session_state["input_labels"] = []
-    # Initialize state
-    session_state["inference_state"] = predictor.init_state(video_path=video_path
+    # Initialize state *without* the device argument
+    session_state["inference_state"] = predictor.init_state(video_path=video_path)
     print("Video loaded and predictor state initialized.")

     return [
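Storing video_path in the session dict is what lets later callbacks rebuild the predictor state without touching the upload widget again. A minimal sketch of that idea, where the stub class is hypothetical and init_state(video_path=...) mirrors the call in the hunk above:

    class StubPredictor:
        def init_state(self, video_path):
            # placeholder for SAM2's real inference-state object
            return {"video": video_path, "points": []}

    def load_video(path, state, predictor_like):
        state["video_path"] = path                     # remember where the frames came from
        state["input_points"], state["input_labels"] = [], []
        state["inference_state"] = predictor_like.init_state(video_path=path)  # no device argument
        return state

    print(load_video("clip.mp4", {}, StubPredictor()))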
@@ -213,9 +218,10 @@ def reset(session_state):
     predictor.reset_state(session_state["inference_state"])
     # After reset, we also discard the state object as a new video might be loaded
     session_state["inference_state"] = None
-    # Clear frames
+    # Clear frames and video path
     session_state["first_frame"] = None
     session_state["all_frames"] = None
+    session_state["video_path"] = None

     # Update UI elements to their initial state
     return (
@@ -238,18 +244,19 @@ def clear_points(session_state):
     session_state["input_points"] = []
     session_state["input_labels"] = []

-    #
+    # Reset the predictor state if it exists. This clears internal masks/features
     # but keeps the video context initialized by preprocess_video_in.
     if session_state["inference_state"] is not None:
         predictor.reset_state(session_state["inference_state"])
-        # After resetting the state, we
-        #
-        if
-
+        # After resetting the state, if we still have the video path, re-initialize the state
+        # to be ready for new points on the same video.
+        if session_state["video_path"] is not None:
+            # Re-initialize state *without* the device argument
+            session_state["inference_state"] = predictor.init_state(video_path=session_state["video_path"])
+            print("Predictor state re-initialized after clearing points.")
         else:
-            # This case should ideally not happen if preprocess_video_in ran correctly
             print("Warning: Could not re-initialize state after clear_points (video_path missing).")
-            session_state["inference_state"] = None
+            session_state["inference_state"] = None # Ensure state is None if video_path is gone


     # Re-render the points_map with no points drawn (just the first frame)
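The re-initialization flow added here (reset the state, then rebuild it only if the video path is still known) can be read in isolation as roughly the following sketch; clear_points_state and predictor_like are illustrative names, not the app's API:

    def clear_points_state(state, predictor_like):
        state["input_points"], state["input_labels"] = [], []
        if state.get("inference_state") is not None:
            predictor_like.reset_state(state["inference_state"])
            if state.get("video_path") is not None:
                # fresh state for the same video, ready for new clicks
                state["inference_state"] = predictor_like.init_state(video_path=state["video_path"])
            else:
                state["inference_state"] = None
        return state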
@@ -324,6 +331,7 @@ def segment_with_points(
     points = np.array(session_state["input_points"], dtype=np.float32)
     labels = np.array(session_state["input_labels"], np.int32)

+    # Ensure tensors are on CPU
    points_tensor = torch.tensor(points, dtype=torch.float32, device="cpu").unsqueeze(0) # Add batch dim
    labels_tensor = torch.tensor(labels, dtype=torch.int32, device="cpu").unsqueeze(0) # Add batch dim

@@ -340,6 +348,7 @@ def segment_with_points(

     # Process logits: detach from graph, move to CPU, apply threshold
     # out_mask_logits is [batch_size, H, W] (batch_size=1 here)
+    # out_mask_logits[0] is the tensor for obj_id=OBJ_ID
     mask_tensor = (out_mask_logits[0][0].detach().cpu() > 0.0) # Apply threshold and get the single mask tensor [H, W]
     mask_numpy = mask_tensor.numpy() # Convert to numpy

@@ -363,7 +372,8 @@ def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
     # Ensure mask is a numpy array (and boolean)
     if isinstance(mask, torch.Tensor):
         mask = mask.detach().cpu().numpy() # Ensure it's on CPU and converted to numpy
-
+        # Convert potential float/int mask to boolean mask
+        mask = mask.astype(bool)

     if random_color:
         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) # RGBA with 0.6 alpha
@@ -374,15 +384,15 @@ def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):

     # Ensure mask has H, W dimensions
     if mask.ndim == 3:
-        mask = mask.squeeze() # Remove singular dimensions
+        mask = mask.squeeze() # Remove singular dimensions like (H, W, 1)
     if mask.ndim != 2:
         print(f"Warning: show_mask received mask with shape {mask.shape}. Expected 2D.")
         # Create an empty transparent image if mask shape is unexpected
+        h, w = mask.shape[:2] if mask.ndim >= 2 else (100, 100) # Use actual shape if possible, otherwise default
         if convert_to_image:
-            return Image.fromarray(np.zeros((
+            return Image.fromarray(np.zeros((h, w, 4), dtype=np.uint8), "RGBA")
         else:
-            return np.zeros((
+            return np.zeros((h, w, 4), dtype=np.uint8)
-

     h, w = mask.shape
     # Create an RGBA image from the mask and color
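Both show_mask hunks reduce to thresholding logits into a boolean mask and turning it into an RGBA overlay. A self-contained example with fake 64x64 logits (sizes and color are arbitrary, not the app's values):

    import numpy as np
    import torch
    from PIL import Image

    logits = torch.randn(1, 1, 64, 64)                  # stand-in for out_mask_logits
    mask = (logits[0][0].detach().cpu() > 0.0).numpy()  # boolean [H, W]
    mask = mask.astype(bool)                            # the explicit cast added above
    color = np.array([1.0, 0.0, 0.0, 0.6])              # red, 60% alpha
    h, w = mask.shape
    overlay = (mask.reshape(h, w, 1) * color.reshape(1, 1, 4) * 255).astype(np.uint8)
    print(Image.fromarray(overlay, "RGBA").size)        # (64, 64)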
@@ -403,7 +413,9 @@ def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):

 # Removed @spaces.GPU decorator
 def propagate_to_all(
-
+    # We don't strictly need video_in path here anymore as it's in session_state,
+    # but keeping it is fine. Accessing session_state["video_path"] is more robust.
+    video_in,
     session_state,
 ):
     """Runs mask propagation through the video and generates the output video."""
@@ -413,6 +425,7 @@ def propagate_to_all(
         len(session_state["input_points"]) == 0 # Need at least one point
         or session_state["all_frames"] is None
         or session_state["inference_state"] is None
+        or session_state["video_path"] is None # Ensure we have the original video path
     ):
         print("Error: Cannot propagate. No points selected, video not loaded, or inference state missing.")
         return (
@@ -424,13 +437,16 @@ def propagate_to_all(
     # The generator yields (frame_idx, obj_ids, mask_logits)
     video_segments = {}
     try:
+        # This loop performs the core tracking prediction frame by frame
         for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
             session_state["inference_state"]
         ):
             # Process logits: detach from graph, move to CPU, convert to numpy boolean mask
             # Ensure tensor is on CPU before converting to numpy
             video_segments[out_frame_idx] = {
-
+                # out_mask_logits is a list of tensors (one per object tracked in this frame)
+                # Each tensor is [batch_size, H, W]. Batch size is 1 here.
+                out_obj_id: (out_mask_logits[i][0].detach().cpu() > 0.0).numpy()
                 for i, out_obj_id in enumerate(out_obj_ids)
             }
             # Optional: print progress
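The comprehension added in this hunk collects one boolean mask per tracked object for each frame. With fake tensors it behaves like this:

    import torch

    out_obj_ids = [0]                                  # single tracked object, like OBJ_ID
    out_mask_logits = [torch.randn(1, 48, 48)]         # one [1, H, W] logits tensor per object
    frame_masks = {
        obj_id: (out_mask_logits[i][0].detach().cpu() > 0.0).numpy()
        for i, obj_id in enumerate(out_obj_ids)
    }
    print({k: (v.shape, v.dtype) for k, v in frame_masks.items()})  # {0: ((48, 48), dtype('bool'))}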
@@ -447,7 +463,8 @@ def propagate_to_all(

     output_frames = []
     # Iterate through all original frames to generate output video
-
+    total_frames = len(session_state["all_frames"])
+    for out_frame_idx in range(total_frames):
         original_frame_rgb = session_state["all_frames"][out_frame_idx]
         # Convert original frame to RGBA for compositing
         transparent_background = Image.fromarray(original_frame_rgb).convert("RGBA")
@@ -471,16 +488,17 @@ def propagate_to_all(


     # Define output path in a temporary directory
+    # Use os.path.join for cross-platform compatibility
     unique_id = datetime.now().strftime("%Y%m%d%H%M%S%f") # Use microseconds for more uniqueness
     final_vid_filename = f"output_video_{unique_id}.mp4"
-    # Use os.path.join for cross-platform compatibility
     final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_filename)
     print(f"Output video path: {final_vid_output_path}")


     # Create a video clip from the image sequence
     # Get original FPS or default
-
+    # Get FPS from the stored video path in session state
+    original_fps = get_video_fps(session_state["video_path"])
     fps = original_fps if original_fps is not None and original_fps > 0 else 30 # Default to 30 if detection fails or is zero
     print(f"Creating output video with FPS: {fps}")

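The path and FPS handling in this hunk is plain standard-library logic; a runnable reduction (get_video_fps is defined elsewhere in app.py, so a failed detection is simulated with None):

    import os
    import tempfile
    from datetime import datetime

    unique_id = datetime.now().strftime("%Y%m%d%H%M%S%f")  # microsecond-level uniqueness
    final_vid_output_path = os.path.join(tempfile.gettempdir(), f"output_video_{unique_id}.mp4")

    original_fps = None                                     # e.g. FPS detection failed
    fps = original_fps if original_fps is not None and original_fps > 0 else 30
    print(final_vid_output_path, fps)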
@@ -512,7 +530,7 @@ def propagate_to_all(
         final_vid_output_path,
         codec="libx264",
         fps=fps, # Ensure correct FPS is used during writing
-        preset="medium", # CPU optimization: 'fast', 'faster', 'veryfast' are options for speed
+        preset="medium", # CPU optimization: 'fast', 'faster', 'veryfast' are options for speed vs size
         threads="auto", # CPU optimization: Use multiple cores
         logger=None # Suppress moviepy output
     )
@@ -714,8 +732,8 @@ with gr.Blocks() as demo:
     ).then( # Then, run the propagation function
         fn=propagate_to_all,
         inputs=[
-            video_in, # Get the input video path
-            session_state, # Pass session state (contains frames, points, inference_state)
+            video_in, # Get the input video path (can also get from session_state["video_path"])
+            session_state, # Pass session state (contains frames, points, inference_state, video_path)
         ],
         outputs=[
             output_video, # Update output video player with result