# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""CPU-only Gradio demo: click points on a video's first frame, then track the
selected object through the whole clip with EdgeTAM."""

import copy
import os
import tempfile
from datetime import datetime

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from moviepy.editor import ImageSequenceClip
from PIL import Image

from sam2.build_sam import build_sam2_video_predictor

# Page title rendered with gr.Markdown.
# NOTE(review): only the visible text of the title survives here; the original
# markup (presumably a link to the EdgeTAM GitHub repository) appears to have
# been stripped. The string previously spanned raw newlines inside a
# single-quoted literal (a SyntaxError); the newlines are now explicit escapes.
title = "\nEdgeTAM [GitHub]\n"

# Usage instructions rendered above the input video.
description_p = """# Instructions
  1. Upload one video or click one example video
  2. Click 'include' point type, select the object to segment and track
  3. Click 'exclude' point type (optional), select the area you want to avoid segmenting and tracking
  4. Click the 'Track' button to obtain the masked video
"""

# Example input videos offered through gr.Examples (paths are relative to the
# working directory the app is launched from).
examples = [
    ["examples/01_dog.mp4"],
    ["examples/02_cups.mp4"],
    ["examples/03_blocks.mp4"],
    ["examples/04_coffee.mp4"],
    ["examples/05_default_juggle.mp4"],
    ["examples/01_breakdancer.mp4"],
    ["examples/02_hummingbird.mp4"],
    ["examples/03_skateboarder.mp4"],
    ["examples/04_octopus.mp4"],
    ["examples/05_landing_dog_soccer.mp4"],
    ["examples/06_pingpong.mp4"],
    ["examples/07_snowboarder.mp4"],
    ["examples/08_driving.mp4"],
    ["examples/09_birdcartoon.mp4"],
    ["examples/10_cloth_magic.mp4"],
    ["examples/11_polevault.mp4"],
    ["examples/12_hideandseek.mp4"],
    ["examples/13_butterfly.mp4"],
    ["examples/14_social_dog_training.mp4"],
    ["examples/15_cricket.mp4"],
    ["examples/16_robotarm.mp4"],
    ["examples/17_childrendancing.mp4"],
    ["examples/18_threedogs.mp4"],
    ["examples/19_cyclist.mp4"],
    ["examples/20_doughkneading.mp4"],
    ["examples/21_biker.mp4"],
    ["examples/22_dogskateboarder.mp4"],
    ["examples/23_racecar.mp4"],
    ["examples/24_clownfish.mp4"],
]

# Object id under which the single tracked object is registered with the
# predictor (the demo tracks exactly one object).
OBJ_ID = 0

sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"

# Build the video predictor for CPU inference. The explicit .to("cpu") is
# redundant with device="cpu" but kept as a harmless safeguard.
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
predictor.to("cpu")
print("predictor loaded on CPU")
# Radio value -> prompt label passed to the predictor (1 = positive click,
# 0 = negative click).
_POINT_LABELS = {"include": 1, "exclude": 0}


def _empty_session_state():
    """Return a fresh per-session state dict with no video, prompts or predictor state."""
    return {
        "first_frame": None,  # numpy array (RGB) of the first decoded frame
        "all_frames": None,  # list of numpy arrays (RGB), one per decoded frame
        "input_points": [],  # clicked (x, y) coordinates, in click order
        "input_labels": [],  # 1 / 0 per click, parallel to input_points
        "inference_state": None,  # object returned by predictor.init_state
        "video_path": None,  # path of the currently loaded input video
    }


def get_video_fps(video_path):
    """Return the frame rate OpenCV reports for ``video_path``.

    Returns None (after printing a message) when the path is missing or the
    file cannot be opened. Callers treat a non-positive result as "unknown".
    """
    if video_path is None or not os.path.exists(video_path):
        print(f"Warning: Video file not found at {video_path}")
        return None
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"Error: Could not open video file {video_path}.")
            return None
        return cap.get(cv2.CAP_PROP_FPS)
    finally:
        cap.release()


def _read_rgb_frames(video_path):
    """Decode every frame of ``video_path`` into RGB numpy arrays.

    Returns None if OpenCV cannot open the file, otherwise a (possibly empty)
    list of frames. The capture is always released.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        frames = []
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            # OpenCV decodes to BGR; the rest of the demo works in RGB.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        return frames
    finally:
        cap.release()


def _no_video_outputs():
    """UI outputs for preprocess_video_in when no usable video is available."""
    return (
        gr.update(open=True),  # video_in_drawer: reopen so a video can be picked
        None,  # points_map
        None,  # output_image
        gr.update(value=None, visible=False),  # output_video
        gr.update(interactive=False),  # propagate_btn
        gr.update(interactive=False),  # clear_points_btn
        gr.update(interactive=False),  # reset_btn
        _empty_session_state(),  # session_state
    )


def preprocess_video_in(video_path, session_state):
    """Load ``video_path`` into the session and initialise the predictor state.

    Args:
        video_path: Path of the uploaded/selected video, or None when cleared.
        session_state: Per-session state dict; updated in place on success.

    Returns:
        Tuple matching the UI outputs: (video_in_drawer, points_map,
        output_image, output_video, propagate_btn, clear_points_btn,
        reset_btn, session_state).
    """
    print(f"Processing video: {video_path}")
    if video_path is None or not os.path.exists(video_path):
        print("No video path provided or file not found.")
        return _no_video_outputs()

    all_frames = _read_rgb_frames(video_path)
    if all_frames is None:
        print(f"Error: Could not open video file {video_path}.")
        return _no_video_outputs()
    if not all_frames:
        print(f"Error: No frames read from video file {video_path}.")
        return _no_video_outputs()

    first_frame = all_frames[0]
    session_state["first_frame"] = copy.deepcopy(first_frame)
    session_state["all_frames"] = all_frames
    session_state["input_points"] = []
    session_state["input_labels"] = []
    # NOTE(review): confirm predictor.init_state accepts a `device` keyword;
    # upstream SAM 2 appears to take the device from the model instead.
    session_state["inference_state"] = predictor.init_state(video_path=video_path, device="cpu")
    session_state["video_path"] = video_path
    print("Video loaded and predictor state initialized.")

    return (
        gr.update(open=False),  # video_in_drawer: collapse once loaded
        first_frame,  # points_map shows the first frame for clicking
        None,  # output_image cleared
        gr.update(value=None, visible=False),  # output_video hidden
        gr.update(interactive=True),  # propagate_btn
        gr.update(interactive=True),  # clear_points_btn
        gr.update(interactive=True),  # reset_btn
        session_state,
    )


def reset(session_state):
    """Discard the loaded video, prompts and predictor state; restore the initial UI.

    Returns a tuple matching the UI outputs: (video_in, video_in_drawer,
    points_map, output_image, output_video, propagate_btn, clear_points_btn,
    reset_btn, session_state).
    """
    print("Resetting demo.")
    if session_state["inference_state"] is not None:
        predictor.reset_state(session_state["inference_state"])
    # Drop everything, including the predictor state: a new video may follow.
    session_state.update(_empty_session_state())
    return (
        None,  # video_in cleared
        gr.update(open=True),  # video_in_drawer reopened
        None,  # points_map cleared
        None,  # output_image cleared
        gr.update(value=None, visible=False),  # output_video hidden
        gr.update(interactive=False),  # propagate_btn
        gr.update(interactive=False),  # clear_points_btn
        gr.update(interactive=False),  # reset_btn
        session_state,
    )


def clear_points(session_state):
    """Remove all click prompts while keeping the loaded video.

    ``predictor.reset_state`` clears prompts and masks but keeps the video
    context created by preprocess_video_in. The same inference state can
    therefore take new clicks immediately. Re-running init_state would only
    re-decode the whole video.

    Returns a tuple matching the UI outputs: (points_map, output_image,
    output_video, session_state).
    """
    print("Clearing points.")
    session_state["input_points"] = []
    session_state["input_labels"] = []
    if session_state["inference_state"] is not None:
        predictor.reset_state(session_state["inference_state"])
    return (
        session_state["first_frame"],  # points_map: first frame without points
        None,  # output_image cleared
        gr.update(value=None, visible=False),  # output_video hidden
        session_state,
    )


def segment_with_points(
    point_type,
    session_state,
    evt: gr.SelectData,
):
    """Record a click prompt and segment the object on the first frame.

    Every call sends *all* accumulated points for OBJ_ID on frame 0.

    Args:
        point_type: Radio value, "include" (positive) or "exclude" (negative).
        session_state: Per-session state dict; the click is appended in place.
        evt: Gradio select event; ``evt.index`` is the clicked pixel (x, y).

    Returns:
        (points_map image with every click drawn, first frame with the mask
        overlaid or None if segmentation failed, session_state).
    """
    if session_state["first_frame"] is None or session_state["inference_state"] is None:
        print("Error: Cannot segment. No video loaded or inference state missing.")
        return session_state["first_frame"], None, session_state

    click_coords = evt.index
    print(f"Clicked at: {click_coords} ({point_type})")
    label = _POINT_LABELS.get(point_type)
    if label is None:
        # Storing a point without a label would desynchronise input_points and
        # input_labels, so ignore the click and leave both images unchanged.
        print(f"Error: Unknown point type {point_type!r}; click ignored.")
        return gr.update(), gr.update(), session_state
    session_state["input_points"].append(click_coords)
    session_state["input_labels"].append(label)

    first_frame_pil = Image.fromarray(session_state["first_frame"]).convert("RGBA")
    w, h = first_frame_pil.size
    # Point radius is 1% of the shorter frame side, but never below 2 px.
    radius = max(2, int(0.01 * min(w, h)))

    # Draw all clicks on a transparent RGBA layer: green = include, red = exclude.
    points_layer = np.zeros((h, w, 4), dtype=np.uint8)
    for point, point_label in zip(session_state["input_points"], session_state["input_labels"]):
        color = (0, 255, 0, 255) if point_label == 1 else (255, 0, 0, 255)
        cv2.circle(points_layer, (int(point[0]), int(point[1])), radius, color, -1)
    selected_point_map_img = Image.alpha_composite(
        first_frame_pil.copy(), Image.fromarray(points_layer, "RGBA")
    )

    # (1, num_points, 2) coordinates and (1, num_points) labels on the CPU.
    points_tensor = torch.tensor(
        np.array(session_state["input_points"], dtype=np.float32),
        dtype=torch.float32,
        device="cpu",
    ).unsqueeze(0)
    labels_tensor = torch.tensor(
        np.array(session_state["input_labels"], dtype=np.int32),
        dtype=torch.int32,
        device="cpu",
    ).unsqueeze(0)

    try:
        _, _, out_mask_logits = predictor.add_new_points(
            inference_state=session_state["inference_state"],
            frame_idx=0,
            obj_id=OBJ_ID,
            points=points_tensor,
            labels=labels_tensor,
        )
        # Assumed layout (num_objects, 1, H, W): [0][0] selects the single
        # object's (H, W) logits; positive logits are foreground.
        mask_numpy = (out_mask_logits[0][0].detach().cpu() > 0.0).numpy()
        mask_image_pil = show_mask(mask_numpy, obj_id=OBJ_ID)
        first_frame_output_img = Image.alpha_composite(first_frame_pil.copy(), mask_image_pil)
    except Exception as e:  # UI boundary: report and keep the app responsive
        print(f"Error during segmentation on first frame: {e}")
        first_frame_output_img = None

    return selected_point_map_img, first_frame_output_img, session_state


def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
    """Render a mask as a semi-transparent (alpha 0.6) RGBA overlay.

    Args:
        mask: Boolean-convertible numpy array or torch tensor; a 3-D input is
            squeezed (e.g. (1, H, W) -> (H, W)).
        obj_id: Selects the matplotlib "tab10" colour (``obj_id % 10``);
            colour 0 when None.
        random_color: Use a random RGB colour instead of the tab10 colour.
        convert_to_image: Return a PIL RGBA image instead of a uint8 array.

    Returns:
        RGBA overlay as a PIL image or a (H, W, 4) uint8 array. If the mask
        is not 2-D after squeezing, a warning is printed and an all-zero
        overlay is returned instead.
    """
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    mask = mask.astype(bool)

    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap_idx = 0 if obj_id is None else obj_id % 10
        color = np.array([*plt.get_cmap("tab10")(cmap_idx)[:3], 0.6])

    if mask.ndim == 3:
        mask = mask.squeeze()
    if mask.ndim != 2:
        print(f"Warning: show_mask received mask with shape {mask.shape}. Expected 2D.")
        blank = np.zeros((*mask.shape[:2], 4), dtype=np.uint8)
        return Image.fromarray(blank, "RGBA") if convert_to_image else blank

    h, w = mask.shape
    colored_mask = np.zeros((h, w, 4), dtype=np.float32)  # fully transparent
    colored_mask[mask] = color
    colored_mask_uint8 = (colored_mask * 255).astype(np.uint8)
    return Image.fromarray(colored_mask_uint8, "RGBA") if convert_to_image else colored_mask_uint8


def _hide_output_video(session_state):
    """UI outputs for propagate_to_all when no result video can be shown."""
    return gr.update(value=None, visible=False), session_state


def _render_tracked_frames(all_frames, video_segments):
    """Overlay the OBJ_ID mask on each original RGB frame.

    ``video_segments`` maps frame index -> {object id: mask}; frames without a
    mask for OBJ_ID are passed through unchanged.
    """
    output_frames = []
    for frame_idx, frame_rgb in enumerate(all_frames):
        frame_masks = video_segments.get(frame_idx, {})
        if OBJ_ID not in frame_masks:
            output_frames.append(frame_rgb)
            continue
        frame_rgba = Image.fromarray(frame_rgb).convert("RGBA")
        overlay = Image.alpha_composite(frame_rgba, show_mask(frame_masks[OBJ_ID], obj_id=OBJ_ID))
        output_frames.append(np.array(overlay.convert("RGB")))
    return output_frames


def _remove_partial_file(path):
    """Best-effort removal of a partially written output file."""
    if not os.path.exists(path):
        return
    try:
        os.remove(path)
        print(f"Removed partial video file: {path}")
    except OSError as clean_e:
        print(f"Error removing partial file: {clean_e}")


def propagate_to_all(
    video_in,
    session_state,
):
    """Track the prompted object through the whole video and write an MP4.

    Args:
        video_in: Input video path (used to recover the original FPS).
        session_state: Per-session state; needs frames, points and an
            inference state.

    Returns:
        (output_video update, session_state). The video player is shown with
        the result path on success and hidden on any failure.
    """
    print("Starting propagation...")
    if (
        not session_state["input_points"]
        or session_state["all_frames"] is None
        or session_state["inference_state"] is None
    ):
        print("Error: Cannot propagate. No points selected, video not loaded, or inference state missing.")
        return _hide_output_video(session_state)

    # frame index -> {object id: boolean mask}
    video_segments = {}
    try:
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
            session_state["inference_state"]
        ):
            video_segments[out_frame_idx] = {
                out_obj_id: (out_mask_logits[i].detach().cpu() > 0.0).numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }
        print("Propagation finished.")
    except Exception as e:  # UI boundary: report and hide the player
        print(f"Error during propagation: {e}")
        return _hide_output_video(session_state)

    output_frames = _render_tracked_frames(session_state["all_frames"], video_segments)

    # Microsecond timestamp keeps successive outputs from colliding.
    unique_id = datetime.now().strftime("%Y%m%d%H%M%S%f")
    final_vid_output_path = os.path.join(tempfile.gettempdir(), f"output_video_{unique_id}.mp4")
    print(f"Output video path: {final_vid_output_path}")

    # Fall back to the path remembered at load time if the input is unset.
    source_path = video_in if video_in is not None else session_state.get("video_path")
    original_fps = get_video_fps(source_path)
    fps = original_fps if original_fps is not None and original_fps > 0 else 30
    print(f"Creating output video with FPS: {fps}")

    if not output_frames:
        print("No output frames generated.")
        return _hide_output_video(session_state)

    try:
        clip = ImageSequenceClip(output_frames, fps=fps)
    except Exception as e:
        print(f"Error creating ImageSequenceClip: {e}")
        return _hide_output_video(session_state)

    try:
        print(f"Writing video file with codec='libx264', fps={fps}, preset='medium', threads='auto'")
        clip.write_videofile(
            final_vid_output_path,
            codec="libx264",  # broadly playable H.264
            fps=fps,
            preset="medium",
            threads="auto",
            logger=None,  # keep moviepy progress bars out of the Gradio logs
        )
    except Exception as e:
        print(f"Error writing video file: {e}")
        _remove_partial_file(final_vid_output_path)
        return _hide_output_video(session_state)

    print("Video writing complete.")
    return gr.update(value=final_vid_output_path, visible=True), session_state


def update_output_video_visibility():
    """Return a Gradio update that makes the output video player visible."""
    return gr.update(visible=True)


with gr.Blocks() as demo:
    # Per-session state: decoded frames, click prompts and the predictor state.
    session_state = gr.State(_empty_session_state())

    with gr.Column():
        gr.Markdown(title)
        with gr.Row():
            with gr.Column():
                gr.Markdown(description_p)
                with gr.Accordion("Input Video", open=True) as video_in_drawer:
                    video_in = gr.Video(label="Input Video", format="mp4")

                with gr.Row():
                    point_type = gr.Radio(
                        label="point type",
                        choices=["include", "exclude"],
                        value="include",
                        scale=2,
                        interactive=True,
                    )
                    # Buttons stay disabled until a video has been loaded.
                    propagate_btn = gr.Button("Track", scale=1, variant="primary", interactive=False)
                    clear_points_btn = gr.Button("Clear Points", scale=1, interactive=False)
                    reset_btn = gr.Button("Reset", scale=1, interactive=False)

                # Users click here to add prompts; shows the first frame with points drawn.
                points_map = gr.Image(
                    label="Frame with Point Prompt",
                    type="numpy",
                    interactive=True,
                    height=400,
                    width="auto",
                    show_share_button=False,
                    show_download_button=False,
                )

            with gr.Column():
                gr.Markdown("# Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[video_in],
                    examples_per_page=8,
                    cache_examples=False,  # results depend on per-session state
                )
                # Segmentation mask predicted on the first frame.
                output_image = gr.Image(
                    label="Reference Mask (First Frame)",
                    type="numpy",
                    interactive=False,
                    height=400,
                    width="auto",
                    show_share_button=False,
                    show_download_button=False,
                )
                # Final tracking result.
                output_video = gr.Video(visible=False, label="Tracking Result")

    # --- Event handlers ---

    # `change` fires for uploads, example selection and programmatic clears
    # (e.g. reset). A separate `upload` listener would run the expensive
    # decode + init_state a second time for every upload.
    video_in.change(
        fn=preprocess_video_in,
        inputs=[video_in, session_state],
        outputs=[
            video_in_drawer,
            points_map,
            output_image,
            output_video,
            propagate_btn,
            clear_points_btn,
            reset_btn,
            session_state,
        ],
        queue=False,
    )

    # A click on the first frame adds a prompt and re-segments frame 0.
    points_map.select(
        fn=segment_with_points,
        inputs=[point_type, session_state],
        outputs=[points_map, output_image, session_state],
        queue=False,
    )

    clear_points_btn.click(
        fn=clear_points,
        inputs=[session_state],
        outputs=[points_map, output_image, output_video, session_state],
        queue=False,
    )

    reset_btn.click(
        fn=reset,
        inputs=[session_state],
        outputs=[
            video_in,
            video_in_drawer,
            points_map,
            output_image,
            output_video,
            propagate_btn,
            clear_points_btn,
            reset_btn,
            session_state,
        ],
        queue=False,
    )

    # Show the player first, then run the (slow, CPU-bound) propagation.
    propagate_btn.click(
        fn=update_output_video_visibility,
        inputs=[],
        outputs=[output_video],
        queue=False,
    ).then(
        fn=propagate_to_all,
        inputs=[video_in, session_state],
        outputs=[output_video, session_state],
        # One propagation at a time to avoid exhausting CPU/RAM; others wait in the queue.
        concurrency_limit=1,
        queue=True,
    )

demo.queue()
print("Gradio demo starting...")
demo.launch()
print("Gradio demo launched.")