# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Gradio demo: click points on a video's first frame, then track the object with EdgeTAM."""

import copy
import os
from datetime import datetime

import gradio as gr

# Set before `import torch` below. Kept from the original demo; presumably read
# by PyTorch when it configures cuDNN SDPA -- TODO confirm the intent.
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0,1,2,3,4,5,6,7"

import tempfile

import cv2
import matplotlib.pyplot as plt

# `spaces` supplies the @spaces.GPU decorator used for Hugging Face Spaces GPU
# allocation. The import is unconditional, so the package must be installed;
# to run somewhere without it, remove this import and the @spaces.GPU decorators.
import spaces
import numpy as np
import torch
from moviepy.editor import ImageSequenceClip
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor

# Description (rendered with gr.Markdown)
title = "\nEdgeTAM [GitHub]\n"
description_p = """# Instructions
  1. Upload one video or click one example video
  2. Click 'include' point type, select the object to segment and track
  3. Click 'exclude' point type (optional), select the area you want to avoid segmenting and tracking
  4. Click the 'Track' button to obtain the masked video
"""

# Example videos offered in the UI.
examples = [
    ["examples/01_dog.mp4"],
    ["examples/02_cups.mp4"],
    ["examples/03_blocks.mp4"],
    ["examples/04_coffee.mp4"],
    ["examples/05_default_juggle.mp4"],
    ["examples/01_breakdancer.mp4"],
    ["examples/02_hummingbird.mp4"],
    ["examples/03_skateboarder.mp4"],
    ["examples/04_octopus.mp4"],
    ["examples/05_landing_dog_soccer.mp4"],
    ["examples/06_pingpong.mp4"],
    ["examples/07_snowboarder.mp4"],
    ["examples/08_driving.mp4"],
    ["examples/09_birdcartoon.mp4"],
    ["examples/10_cloth_magic.mp4"],
    ["examples/11_polevault.mp4"],
    ["examples/12_hideandseek.mp4"],
    ["examples/13_butterfly.mp4"],
    ["examples/14_social_dog_training.mp4"],
    ["examples/15_cricket.mp4"],
    ["examples/16_robotarm.mp4"],
    ["examples/17_childrendancing.mp4"],
    ["examples/18_threedogs.mp4"],
    ["examples/19_cyclist.mp4"],
    ["examples/20_doughkneading.mp4"],
    ["examples/21_biker.mp4"],
    ["examples/22_dogskateboarder.mp4"],
    ["examples/23_racecar.mp4"],
    ["examples/24_clownfish.mp4"],
]

# Object id used for the single object this demo tracks.
OBJ_ID = 0

sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"

# The demo targets a CUDA GPU. Pick the device explicitly so that a CPU-only
# host reaches the warning below instead of crashing inside `.to("cuda")`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# The model is built on CPU and then moved to the selected device.
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
predictor.to(DEVICE)
print(f"predictor loaded on {DEVICE.upper()}")

if DEVICE == "cuda":
    # Use bfloat16 autocast for the entire demo: the context is entered once
    # here and never exited.
    # NOTE(review): autocast state is thread-local (PyTorch AMP docs), so this
    # may not cover handlers that Gradio runs on other threads -- confirm.
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    if torch.cuda.get_device_properties(0).major >= 8:
        # Turn on TF32 for Ampere GPUs (compute capability major >= 8):
        # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
else:
    print("Warning: CUDA not available. The original code is configured for GPU.")
def get_video_fps(video_path):
    """Return the frame rate OpenCV reports for ``video_path``.

    Returns None if the file cannot be opened. OpenCV returns 0 for properties
    the backend does not support, so callers should treat non-positive values
    as "unknown".
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print("Error: Could not open video.")
            return None
        return cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Release the capture on every path, including the failed-open one.
        cap.release()


def _empty_session_state():
    """Return a new per-session state dict describing "no video loaded"."""
    return {
        "first_frame": None,  # RGB numpy array of frame 0
        "all_frames": None,  # list of RGB numpy arrays, one per frame
        "input_points": [],  # clicked (x, y) coordinates on frame 0
        "input_labels": [],  # parallel to input_points: 1 = include, 0 = exclude
        "inference_state": None,  # object returned by predictor.init_state
        "video_path": None,  # path of the currently loaded video
    }


def _disabled_video_outputs():
    """Outputs for ``preprocess_video_in`` when no usable video is available.

    The order matches the ``outputs`` list wired to ``video_in.change``.
    """
    return (
        gr.update(open=True),  # video_in_drawer
        None,  # points_map
        None,  # output_image
        gr.update(value=None, visible=False),  # output_video
        gr.update(interactive=False),  # propagate_btn
        gr.update(interactive=False),  # clear_points_btn
        gr.update(interactive=False),  # reset_btn
        _empty_session_state(),  # session_state
    )


def _autocast():
    """Return a bfloat16 CUDA autocast context for model calls (disabled without CUDA).

    Entered per call because autocast state is thread-local, and Gradio may run
    handlers on a thread other than the one that imported this module. Entering
    it inside an identical, already-active autocast region changes nothing.
    """
    return torch.autocast(
        device_type="cuda", dtype=torch.bfloat16, enabled=torch.cuda.is_available()
    )


def reset(session_state):
    """Reset the predictor, the session state, and the UI to the "no video" state.

    Returns one value per output wired to ``reset_btn.click``: video_in,
    video_in_drawer, points_map, output_image, output_video, propagate_btn,
    clear_points_btn, reset_btn, session_state.
    """
    print("Resetting demo.")
    if session_state["inference_state"] is not None:
        try:
            predictor.reset_state(session_state["inference_state"])
        except Exception as e:
            # Best effort: the state object is discarded just below either way.
            print(f"Error resetting predictor state: {e}")
    # Clear every field (frames, prompts, predictor state, video path) in place.
    session_state.update(_empty_session_state())
    return (
        None,  # video_in (clears the video player)
        gr.update(open=True),  # video_in_drawer (opens accordion)
        None,  # points_map
        None,  # output_image
        gr.update(value=None, visible=False),  # output_video (hidden and cleared)
        gr.update(interactive=False),  # propagate_btn: nothing to track yet
        gr.update(interactive=False),  # clear_points_btn
        gr.update(interactive=False),  # reset_btn
        session_state,
    )


def clear_points(session_state):
    """Drop all clicked points and reset the predictor's prompts for this video.

    The loaded video (frames, path, inference state) is kept. Returns values for
    points_map (the plain first frame), output_image (cleared), output_video
    (hidden) and session_state.
    """
    print("Clearing points.")
    session_state["input_points"] = []
    session_state["input_labels"] = []
    if session_state["inference_state"] is not None:
        try:
            predictor.reset_state(session_state["inference_state"])
            print("Predictor state reset for clearing points.")
        except Exception as e:
            # NOTE(review): if this fails, stale prompts/masks may remain in the
            # predictor state; a re-init via predictor.init_state may be needed.
            print(f"Error resetting predictor state during clear_points: {e}")
    return (
        session_state["first_frame"],  # points_map: first frame without points
        None,  # output_image: no mask
        gr.update(value=None, visible=False),  # output_video: hidden
        session_state,
    )


def _read_rgb_frames(video_path):
    """Decode every frame of ``video_path`` into RGB numpy arrays.

    Returns None if OpenCV cannot open the file; otherwise a list (possibly
    empty) of frames converted from OpenCV's BGR channel order to RGB.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        frames = []
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        return frames
    finally:
        cap.release()


@spaces.GPU
def preprocess_video_in(video_path, session_state):
    """Load ``video_path`` and initialize a predictor inference state for it.

    Returns one value per output wired to ``video_in.change``: video_in_drawer,
    points_map, output_image, output_video, propagate_btn, clear_points_btn,
    reset_btn, session_state. On any failure the UI goes back to the "no video"
    state with a fresh session state and disabled buttons.
    """
    print(f"Processing video: {video_path}")
    if video_path is None or not os.path.exists(video_path):
        print("No video path provided or file not found.")
        return _disabled_video_outputs()

    all_frames = _read_rgb_frames(video_path)
    if all_frames is None:
        print(f"Error: Could not open video file {video_path}.")
        return _disabled_video_outputs()
    if not all_frames:
        print(f"Error: No frames read from video file {video_path}.")
        return _disabled_video_outputs()

    first_frame = all_frames[0]
    session_state["first_frame"] = copy.deepcopy(first_frame)
    session_state["all_frames"] = all_frames
    session_state["video_path"] = video_path
    session_state["input_points"] = []
    session_state["input_labels"] = []
    # No device argument: init_state uses the device the predictor is on.
    with _autocast():
        session_state["inference_state"] = predictor.init_state(video_path=video_path)
    print("Video loaded and predictor state initialized.")

    return (
        gr.update(open=False),  # video_in_drawer
        first_frame,  # points_map
        None,  # output_image
        gr.update(value=None, visible=False),  # output_video
        gr.update(interactive=True),  # propagate_btn
        gr.update(interactive=True),  # clear_points_btn
        gr.update(interactive=True),  # reset_btn
        session_state,
    )


@spaces.GPU
def segment_with_points(
    point_type,
    session_state,
    evt: gr.SelectData,
):
    """Add the clicked point as a prompt and segment the object on frame 0.

    Args:
        point_type: "include" (label 1) or "exclude" (label 0); clicks with any
            other value are ignored.
        session_state: per-session dict; its point/label lists are extended.
        evt: Gradio select event; ``evt.index`` holds the clicked (x, y).

    Returns:
        (points_map image with all points drawn, output_image with the mask
        overlay or None if segmentation failed, session_state).
    """
    if session_state["first_frame"] is None or session_state["inference_state"] is None:
        print("Error: Cannot segment. No video loaded or inference state missing.")
        return session_state.get("first_frame"), None, session_state

    label = {"include": 1, "exclude": 0}.get(point_type)
    if label is None:
        # Recording a point without a label would desynchronize the parallel
        # point/label lists, so leave everything unchanged.
        print(f"Ignoring click with unknown point type: {point_type!r}")
        return gr.update(), gr.update(), session_state

    click_coords = evt.index
    print(f"Clicked at: {click_coords} ({point_type})")
    session_state["input_points"].append(click_coords)
    session_state["input_labels"].append(label)

    # Draw every accumulated point on a transparent layer over frame 0.
    first_frame_pil = Image.fromarray(session_state["first_frame"]).convert("RGBA")
    w, h = first_frame_pil.size
    radius = max(2, int(0.01 * min(w, h)))  # 1% of the short side, at least 2 px
    points_layer = np.zeros((h, w, 4), dtype=np.uint8)
    for (x, y), point_label in zip(session_state["input_points"], session_state["input_labels"]):
        # Green for include, red for exclude (RGBA, fully opaque).
        color = (0, 255, 0, 255) if point_label == 1 else (255, 0, 0, 255)
        cv2.circle(points_layer, (int(x), int(y)), radius, color, -1)
    selected_point_map_img = Image.alpha_composite(first_frame_pil, Image.fromarray(points_layer))

    # Prompts go to the device the model lives on, with a leading batch dim.
    device = next(predictor.parameters()).device
    points_tensor = torch.tensor(
        np.array(session_state["input_points"], dtype=np.float32),
        dtype=torch.float32,
        device=device,
    ).unsqueeze(0)
    labels_tensor = torch.tensor(
        np.array(session_state["input_labels"], dtype=np.int32),
        dtype=torch.int32,
        device=device,
    ).unsqueeze(0)

    first_frame_output_img = None  # stays None if segmentation fails
    try:
        # All accumulated points are sent on every click; the predictor updates
        # the inference state and returns mask logits for frame 0.
        with _autocast():
            _, _, out_mask_logits = predictor.add_new_points(
                inference_state=session_state["inference_state"],
                frame_idx=0,
                obj_id=OBJ_ID,
                points=points_tensor,
                labels=labels_tensor,
            )
        # First object, first batch entry; positive logits are foreground.
        mask_numpy = (out_mask_logits[0][0].detach() > 0.0).cpu().numpy()
        first_frame_output_img = Image.alpha_composite(
            first_frame_pil, show_mask(mask_numpy, obj_id=OBJ_ID)
        )
    except Exception as e:
        print(f"Error during segmentation on first frame: {e}")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return selected_point_map_img, first_frame_output_img, session_state


def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
    """Render a mask as a semi-transparent colored RGBA overlay.

    Args:
        mask: mask as a numpy array or torch tensor; non-zero entries are
            foreground. A 3-D mask is squeezed before use.
        obj_id: picks the ``tab10`` color ``obj_id % 10``; None uses color 0.
        random_color: use a random RGB color instead of the ``tab10`` one.
        convert_to_image: return a PIL RGBA image if True, else a uint8
            (H, W, 4) array.

    Foreground pixels get the color with alpha 0.6 and background pixels are
    fully transparent. A mask that is not 2-D after squeezing produces an
    all-transparent overlay (sized from its trailing two dims, else 100x100).
    """
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    mask = mask.astype(bool)

    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id % 10
        color = np.array([*cmap(cmap_idx)[:3], 0.6])

    if mask.ndim == 3:
        mask = mask.squeeze()  # e.g. (H, W, 1) -> (H, W)

    if mask.ndim != 2:
        print(f"Warning: show_mask received mask with shape {mask.shape}. Expected 2D.")
        h, w = mask.shape[-2:] if mask.ndim >= 2 else (100, 100)
        colored_mask_uint8 = np.zeros((h, w, 4), dtype=np.uint8)
    else:
        colored_mask = np.zeros((*mask.shape, 4), dtype=np.float32)  # transparent
        colored_mask[mask] = color
        colored_mask_uint8 = (colored_mask * 255).astype(np.uint8)

    if convert_to_image:
        # A uint8 (H, W, 4) array is interpreted as RGBA by Pillow.
        return Image.fromarray(colored_mask_uint8)
    return colored_mask_uint8


def _render_tracked_frames(all_frames, video_segments):
    """Overlay the OBJ_ID mask on each frame; return the frames as RGB numpy arrays.

    ``video_segments`` maps frame index -> {object id -> mask}. Frames with no
    mask for OBJ_ID are passed through unchanged.
    """
    output_frames = []
    for frame_idx, frame_rgb in enumerate(all_frames):
        mask = video_segments.get(frame_idx, {}).get(OBJ_ID)
        if mask is None:
            output_frames.append(frame_rgb)
            continue
        base = Image.fromarray(frame_rgb).convert("RGBA")
        composited = Image.alpha_composite(base, show_mask(mask, obj_id=OBJ_ID))
        output_frames.append(np.array(composited.convert("RGB")))
    return output_frames


@spaces.GPU
def propagate_to_all(
    video_in,  # unused: the path stored in session_state at load time is used
    session_state,
):
    """Track the prompted object through the whole video and write an MP4.

    Returns (output_video update, session_state). The update shows the written
    video on success. On any failure it is hidden and has no value.
    """
    print("Starting propagation...")
    hidden_video = gr.update(value=None, visible=False)

    current_video_path = session_state.get("video_path")
    if (
        not session_state["input_points"]
        or session_state["all_frames"] is None
        or session_state["inference_state"] is None
        or current_video_path is None
    ):
        print(
            "Error: Cannot propagate. No points selected, video not loaded, "
            "or inference state missing."
        )
        return hidden_video, session_state

    # frame index -> {object id -> boolean mask array}
    video_segments = {}
    try:
        with _autocast():
            for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
                session_state["inference_state"]
            ):
                video_segments[out_frame_idx] = {
                    out_obj_id: (out_mask_logits[i][0].detach() > 0.0).cpu().numpy()
                    for i, out_obj_id in enumerate(out_obj_ids)
                }
        print("Propagation finished.")
    except Exception as e:
        print(f"Error during propagation: {e}")
        return hidden_video, session_state

    output_frames = _render_tracked_frames(session_state["all_frames"], video_segments)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if not output_frames:
        print("No output frames generated.")
        return hidden_video, session_state

    # Reuse the input frame rate; fall back to 30 when it is unknown.
    original_fps = get_video_fps(current_video_path)
    fps = original_fps if original_fps is not None and original_fps > 0 else 30
    print(f"Creating output video with FPS: {fps}")

    unique_id = datetime.now().strftime("%Y%m%d%H%M%S%f")  # microsecond resolution
    final_vid_output_path = os.path.join(tempfile.gettempdir(), f"output_video_{unique_id}.mp4")
    print(f"Output video path: {final_vid_output_path}")

    try:
        clip = ImageSequenceClip(output_frames, fps=fps)
    except Exception as e:
        print(f"Error creating ImageSequenceClip: {e}")
        return hidden_video, session_state

    try:
        print(f"Writing video file with codec='libx264', fps={fps}")
        clip.write_videofile(final_vid_output_path, codec="libx264", fps=fps)
    except Exception as e:
        print(f"Error writing video file: {e}")
        # Remove any partially written file.
        if os.path.exists(final_vid_output_path):
            try:
                os.remove(final_vid_output_path)
                print(f"Removed partial video file: {final_vid_output_path}")
            except OSError as clean_e:
                print(f"Error removing partial file: {clean_e}")
        return hidden_video, session_state
    finally:
        clip.close()

    print("Video writing complete.")
    return gr.update(value=final_vid_output_path, visible=True), session_state


def update_ui():
    """Return a Gradio update that makes the output video player visible."""
    return gr.update(visible=True)


with gr.Blocks() as demo:
    # Per-session state: video frames, click prompts and predictor state.
    session_state = gr.State(_empty_session_state())

    with gr.Column():
        gr.Markdown(title)
        with gr.Row():
            with gr.Column():
                gr.Markdown(description_p)
                with gr.Accordion("Input Video", open=True) as video_in_drawer:
                    video_in = gr.Video(label="Input Video", format="mp4")

                with gr.Row():
                    point_type = gr.Radio(
                        label="point type",
                        choices=["include", "exclude"],
                        value="include",
                        scale=2,
                        interactive=True,
                    )
                    # Disabled until a video has been loaded successfully.
                    propagate_btn = gr.Button("Track", scale=1, variant="primary", interactive=False)
                    clear_points_btn = gr.Button("Clear Points", scale=1, interactive=False)
                    reset_btn = gr.Button("Reset", scale=1, interactive=False)

                # Users click here to add prompts; shows frame 0 with the points drawn.
                # NOTE(review): interactive=True was chosen deliberately in this
                # version; it also lets users upload/clear an image here.
                points_map = gr.Image(
                    label="Click on the First Frame to Add Points",
                    type="numpy",
                    interactive=True,
                    height=400,
                    width="auto",
                    show_share_button=False,
                    show_download_button=False,
                )

            with gr.Column():
                gr.Markdown("# Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[video_in],
                    examples_per_page=8,
                    cache_examples=False,  # results depend on per-session state
                )

                # Segmentation mask predicted on the first frame.
                output_image = gr.Image(
                    label="Segmentation Mask on First Frame",
                    type="numpy",
                    interactive=False,
                    height=400,
                    width="auto",
                    show_share_button=False,
                    show_download_button=False,
                )

                # Final tracking result, hidden until propagation runs.
                output_video = gr.Video(visible=False, label="Tracking Result")

    # --- Event handlers ---

    # `change` fires both for user uploads and for example selection. A separate
    # `upload` listener would run this GPU preprocessing twice per upload.
    video_in.change(
        fn=preprocess_video_in,
        inputs=[video_in, session_state],
        outputs=[
            video_in_drawer,
            points_map,
            output_image,
            output_video,
            propagate_btn,
            clear_points_btn,
            reset_btn,
            session_state,
        ],
        queue=False,
    )

    # A click on the first frame adds a point and re-segments frame 0.
    points_map.select(
        fn=segment_with_points,
        inputs=[point_type, session_state],
        outputs=[points_map, output_image, session_state],
        queue=False,
    )

    clear_points_btn.click(
        fn=clear_points,
        inputs=[session_state],
        outputs=[points_map, output_image, output_video, session_state],
        queue=False,
    )

    reset_btn.click(
        fn=reset,
        inputs=[session_state],
        outputs=[
            video_in,
            video_in_drawer,
            points_map,
            output_image,
            output_video,
            propagate_btn,
            clear_points_btn,
            reset_btn,
            session_state,
        ],
        queue=False,
    )

    # Show the player first, then run propagation and fill it in.
    propagate_btn.click(
        fn=update_ui,
        inputs=[],
        outputs=[output_video],
        queue=False,
    ).then(
        fn=propagate_to_all,
        inputs=[video_in, session_state],
        outputs=[output_video, session_state],
        # Kept from the original code; may need tuning for the available GPU.
        concurrency_limit=10,
        queue=False,
    )

# Launch the Gradio demo
demo.queue()
print("Gradio demo starting...")
demo.launch()
print("Gradio demo launched.")