# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy
import os
from datetime import datetime
import gradio as gr
# Enable the cuDNN backend for scaled_dot_product_attention (this env var is an on/off flag).
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
import tempfile
import cv2
import matplotlib.pyplot as plt
# The `spaces` import and @spaces.GPU decorators request GPU allocation on Hugging Face
# Spaces; when running locally they are ignored (or can be removed).
import spaces
import numpy as np
import torch
from moviepy.editor import ImageSequenceClip
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor
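# Demo flow: the user uploads (or picks) a video, clicks include/exclude points on the
# first frame to get an initial mask, then presses "Track" to propagate that mask through
# every frame and render the result as a new video.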
# Description
title = "<center><strong><font size='8'>EdgeTAM<font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a> </center>"
description_p = """# Instructions
<ol>
<li> Upload a video or click one of the example videos</li>
<li> With the 'include' point type selected, click the object you want to segment and track</li>
<li> Optionally switch to the 'exclude' point type and click areas that should be kept out of the mask</li>
<li> Click the 'Track' button to generate the masked video</li>
</ol>
"""
# examples
examples = [
["examples/01_dog.mp4"],
["examples/02_cups.mp4"],
["examples/03_blocks.mp4"],
["examples/04_coffee.mp4"],
["examples/05_default_juggle.mp4"],
["examples/01_breakdancer.mp4"],
["examples/02_hummingbird.mp4"],
["examples/03_skateboarder.mp4"],
["examples/04_octopus.mp4"],
["examples/05_landing_dog_soccer.mp4"],
["examples/06_pingpong.mp4"],
["examples/07_snowboarder.mp4"],
["examples/08_driving.mp4"],
["examples/09_birdcartoon.mp4"],
["examples/10_cloth_magic.mp4"],
["examples/11_polevault.mp4"],
["examples/12_hideandseek.mp4"],
["examples/13_butterfly.mp4"],
["examples/14_social_dog_training.mp4"],
["examples/15_cricket.mp4"],
["examples/16_robotarm.mp4"],
["examples/17_childrendancing.mp4"],
["examples/18_threedogs.mp4"],
["examples/19_cyclist.mp4"],
["examples/20_doughkneading.mp4"],
["examples/21_biker.mp4"],
["examples/22_dogskateboarder.mp4"],
["examples/23_racecar.mp4"],
["examples/24_clownfish.mp4"],
]
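# The demo tracks a single object; every point prompt is attached to this one object id.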
OBJ_ID = 0
sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"
# Build the predictor on the best available device. The demo targets CUDA, but a CPU
# fallback keeps it runnable (slowly) when no GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)
print(f"predictor loaded on {device}")
if device == "cuda":
    # Use bfloat16 autocast for the entire demo.
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    if torch.cuda.get_device_properties(0).major >= 8:
        # Turn on tf32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
else:
    print("Warning: CUDA not available; falling back to CPU, which will be much slower.")
def get_video_fps(video_path):
# Open the video file
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
print("Error: Could not open video.")
return None
# Get the FPS of the video
fps = cap.get(cv2.CAP_PROP_FPS)
cap.release() # Release the capture object
return fps
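# Usage sketch: get_video_fps("examples/01_dog.mp4") returns the container-reported FPS
# as a float (or None if the file cannot be opened); propagate_to_all uses it so the
# output video keeps the input's frame rate.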
def reset(session_state):
"""Resets the UI and session state."""
print("Resetting demo.")
session_state["input_points"] = []
session_state["input_labels"] = []
# Reset the predictor state if it exists
if session_state["inference_state"] is not None:
# Assuming predictor.reset_state handles None or invalid states gracefully
# Or you might need to explicitly pass the state object if required
try:
predictor.reset_state(session_state["inference_state"])
# Explicitly delete or re-init the state object if a full reset is intended
# This depends on how predictor.reset_state works
# session_state["inference_state"] = None # Example if reset_state doesn't fully clear
except Exception as e:
print(f"Error resetting predictor state: {e}")
# If reset fails, perhaps force-clear the state object
session_state["inference_state"] = None
session_state["first_frame"] = None
session_state["all_frames"] = None
session_state["inference_state"] = None # Ensure state is None after a full reset
# Also reset video path if stored
session_state["video_path"] = None
# Resetting UI components
return (
None, # video_in (clears the video player)
gr.update(open=True), # video_in_drawer (opens accordion)
None, # points_map (clears the image)
None, # output_image (clears the image)
gr.update(value=None, visible=False), # output_video (hides and clears)
session_state, # return updated session state
)
def clear_points(session_state):
"""Clears selected points and resets segmentation on the first frame."""
print("Clearing points.")
session_state["input_points"] = []
session_state["input_labels"] = []
# Reset the predictor state to clear internal masks/features
# This typically doesn't remove the video context, just the mask predictions
if session_state["inference_state"] is not None:
try:
# Assuming reset_state handles clearing current masks/features
predictor.reset_state(session_state["inference_state"])
print("Predictor state reset for clearing points.")
# If you need to re-initialize the state for the *same* video after clearing points,
# you might need to call predictor.init_state again here, using the stored video_path.
# session_state["inference_state"] = predictor.init_state(video_path=session_state["video_path"], device="cuda") # Or device="cpu" if modified earlier
except Exception as e:
print(f"Error resetting predictor state during clear_points: {e}")
# If reset fails, this might leave old masks. Depending on SAM2's behavior,
# you might need a more aggressive state clear or re-initialization.
# Return the original first frame image for points_map and clear the output_image
first_frame_img = session_state["first_frame"] if session_state["first_frame"] is not None else None
return (
first_frame_img, # points_map shows original first frame (no points yet)
None, # output_image cleared (no mask)
gr.update(value=None, visible=False), # output_video hidden
session_state, # return updated session state
)
# @spaces.GPU requests a GPU from Hugging Face Spaces for the duration of this call (ignored when running locally).
@spaces.GPU
def preprocess_video_in(video_path, session_state):
"""Loads video frames and initializes the predictor state."""
print(f"Processing video: {video_path}")
if video_path is None or not os.path.exists(video_path):
print("No video path provided or file not found.")
# Reset state and UI elements if input is invalid
# Need to return updates for the buttons as well
return (
gr.update(open=True), None, None, gr.update(value=None, visible=False),
gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False),
{ # Reset session state
"first_frame": None, "all_frames": None, "input_points": [],
"input_labels": [], "inference_state": None, "video_path": None,
}
)
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
print(f"Error: Could not open video file {video_path}.")
return (
gr.update(open=True), None, None, gr.update(value=None, visible=False),
gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False),
{ # Reset session state
"first_frame": None, "all_frames": None, "input_points": [],
"input_labels": [], "inference_state": None, "video_path": None,
}
)
first_frame = None
all_frames = []
while True:
ret, frame = cap.read()
if not ret:
break
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
all_frames.append(frame)
if first_frame is None:
first_frame = frame
cap.release()
if not all_frames:
print(f"Error: No frames read from video file {video_path}.")
return (
gr.update(open=True), None, None, gr.update(value=None, visible=False),
gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False),
{ # Reset session state
"first_frame": None, "all_frames": None, "input_points": [],
"input_labels": [], "inference_state": None, "video_path": None,
}
)
session_state["first_frame"] = copy.deepcopy(first_frame)
session_state["all_frames"] = all_frames
session_state["video_path"] = video_path # Store video path
session_state["input_points"] = []
session_state["input_labels"] = []
    # No device is passed here; init_state uses the device the predictor is already on.
session_state["inference_state"] = predictor.init_state(video_path=video_path)
print("Video loaded and predictor state initialized.")
# Enable buttons after successful load
return [
gr.update(open=False), # video_in_drawer
first_frame, # points_map
None, # output_image
gr.update(value=None, visible=False), # output_video
gr.update(interactive=True), # propagate_btn
gr.update(interactive=True), # clear_points_btn
gr.update(interactive=True), # reset_btn
session_state, # session_state
]
# @spaces.GPU requests a GPU from Hugging Face Spaces for the duration of this call (ignored when running locally).
@spaces.GPU
def segment_with_points(
point_type,
session_state,
evt: gr.SelectData,
):
"""Adds a point prompt and performs segmentation on the first frame."""
# Ensure we have state and first frame
if session_state["first_frame"] is None or session_state["inference_state"] is None:
print("Error: Cannot segment. No video loaded or inference state missing.")
# Return current images and state without changes
return (
session_state.get("first_frame"), # points_map (show first frame if exists)
None, # output_image (keep cleared)
session_state,
)
# evt.index is the (x, y) coordinate tuple
click_coords = evt.index
print(f"Clicked at: {click_coords} ({point_type})")
session_state["input_points"].append(click_coords)
if point_type == "include":
session_state["input_labels"].append(1)
elif point_type == "exclude":
session_state["input_labels"].append(0)
# Get the first frame as a PIL image for drawing
first_frame_pil = Image.fromarray(session_state["first_frame"]).convert("RGBA")
w, h = first_frame_pil.size
# Define the circle radius
fraction = 0.01
radius = max(2, int(fraction * min(w, h))) # Ensure minimum radius of 2
# Create a transparent layer to draw points
transparent_layer_points = np.zeros((h, w, 4), dtype=np.uint8)
# Draw points on the transparent layer
for index, track in enumerate(session_state["input_points"]):
# Ensure coordinates are integers for cv2.circle
point_coords = (int(track[0]), int(track[1]))
# Ensure color is RGBA (0-255)
if session_state["input_labels"][index] == 1:
cv2.circle(transparent_layer_points, point_coords, radius, (0, 255, 0, 255), -1) # Green for include
else:
cv2.circle(transparent_layer_points, point_coords, radius, (255, 0, 0, 255), -1) # Red for exclude
# Convert the transparent layer back to an image and composite onto the first frame
transparent_layer_points_pil = Image.fromarray(transparent_layer_points, "RGBA")
# Combine the first frame image with the points layer for the points_map output
# points_map shows the first frame *with the points you added*.
selected_point_map_img = Image.alpha_composite(
first_frame_pil.copy(), transparent_layer_points_pil
)
# Prepare points and labels as tensors on the correct device (CUDA in original code)
points = np.array(session_state["input_points"], dtype=np.float32)
labels = np.array(session_state["input_labels"], np.int32)
    # Put the point/label tensors on the same device as the model.
device = next(predictor.parameters()).device # Get the device the model is on
points_tensor = torch.tensor(points, dtype=torch.float32, device=device).unsqueeze(0) # Add batch dim
labels_tensor = torch.tensor(labels, dtype=torch.int32, device=device).unsqueeze(0) # Add batch dim
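    # After unsqueeze(0), the prompts are batched as points [1, N, 2] (x, y) and labels [1, N],
    # where 1 marks 'include' clicks and 0 marks 'exclude' clicks.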
# Add new points to the predictor's state and get the mask for the first frame
# This call performs segmentation on the current frame (frame_idx=0) using all accumulated points
first_frame_output_img = None # Initialize output mask image as None in case of error
try:
# Note: predictor.add_new_points modifies the internal inference_state
_, _, out_mask_logits = predictor.add_new_points(
inference_state=session_state["inference_state"],
frame_idx=0, # Always segment on the first frame initially
obj_id=OBJ_ID,
points=points_tensor,
labels=labels_tensor,
)
# Process logits: detach from graph, move to CPU, apply threshold
# out_mask_logits is a list of tensors [tensor([batch_size, H, W])] for the requested obj_id
# Access the result for the first object (index 0) and the first item in batch (index 0)
mask_tensor = (out_mask_logits[0][0].detach().cpu() > 0.0) # Move to CPU before converting to numpy
mask_numpy = mask_tensor.numpy() # Convert to numpy
# Get the mask image (RGBA)
mask_image_pil = show_mask(mask_numpy, obj_id=OBJ_ID) # show_mask returns RGBA PIL Image
# Composite the mask onto the first frame for the output_image
# output_image shows the first frame *with the segmentation mask result*.
first_frame_output_img = Image.alpha_composite(first_frame_pil.copy(), mask_image_pil)
except Exception as e:
print(f"Error during segmentation on first frame: {e}")
# On error, first_frame_output_img remains None
    # Free cached GPU memory after the first-frame segmentation.
if torch.cuda.is_available():
torch.cuda.empty_cache()
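    # Return order: points_map image (clicks drawn), output_image (mask composited, or None
    # on error), and the updated session state.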
return selected_point_map_img, first_frame_output_img, session_state
def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
"""Helper function to visualize a mask."""
# Ensure mask is a numpy array (and boolean)
if isinstance(mask, torch.Tensor):
mask = mask.detach().cpu().numpy() # Ensure it's on CPU and converted to numpy
# Convert potential float/int mask to boolean mask
mask = mask.astype(bool)
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) # RGBA with 0.6 alpha
else:
cmap = plt.get_cmap("tab10")
cmap_idx = 0 if obj_id is None else obj_id % 10 # Use modulo 10 for tab10 colors
color = np.array([*cmap(cmap_idx)[:3], 0.6]) # RGBA with 0.6 alpha
# Ensure mask has H, W dimensions
if mask.ndim == 3:
mask = mask.squeeze() # Remove singular dimensions like (H, W, 1)
if mask.ndim != 2:
print(f"Warning: show_mask received mask with shape {mask.shape}. Expected 2D.")
# Create an empty transparent image if mask shape is unexpected
h, w = mask.shape[:2] if mask.ndim >= 2 else (100, 100) # Use actual shape if possible, otherwise default
if convert_to_image:
return Image.fromarray(np.zeros((h, w, 4), dtype=np.uint8), "RGBA")
else:
return np.zeros((h, w, 4), dtype=np.uint8)
h, w = mask.shape
# Create an RGBA image from the mask and color
# Apply color where mask is True
# Need to reshape color to be broadcastable [1, 1, 4]
colored_mask = np.zeros((h, w, 4), dtype=np.float32) # Start with fully transparent black
# Apply the color only where the mask is True.
# This directly creates the colored overlay with transparency.
colored_mask[mask] = color
# Convert to uint8 [0-255]
colored_mask_uint8 = (colored_mask * 255).astype(np.uint8)
if convert_to_image:
mask_img = Image.fromarray(colored_mask_uint8, "RGBA")
return mask_img
else:
return colored_mask_uint8
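# Usage sketch for show_mask (mirrors how it is used above): given a boolean H x W numpy mask,
# it returns an RGBA overlay that can be composited onto a frame:
#   overlay = show_mask(mask, obj_id=OBJ_ID)              # RGBA PIL.Image
#   blended = Image.alpha_composite(frame_rgba, overlay)  # frame_rgba: RGBA PIL.Image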
# @spaces.GPU requests a GPU from Hugging Face Spaces for the duration of this call (ignored when running locally).
@spaces.GPU
def propagate_to_all(
video_in, # Keep video_in path as in original
session_state,
):
"""Runs mask propagation through the video and generates the output video."""
print("Starting propagation...")
# Ensure state is ready
# Using session_state.get("video_path") is safer than video_in directly
current_video_path = session_state.get("video_path")
if (
len(session_state["input_points"]) == 0 # Need at least one point
or session_state["all_frames"] is None
or session_state["inference_state"] is None
or current_video_path is None # Ensure we have the original video path
):
print("Error: Cannot propagate. No points selected, video not loaded, or inference state missing.")
return (
gr.update(value=None, visible=False), # Hide output video on error
session_state,
)
# run propagation throughout the video and collect the results
video_segments = {}
try:
# This loop performs the core tracking prediction frame by frame
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
session_state["inference_state"]
):
# Process logits: detach from graph, move to CPU, convert to numpy boolean mask
# Ensure tensor is on CPU before converting to numpy
video_segments[out_frame_idx] = {
# out_mask_logits is a list of tensors (one per object tracked in this frame)
# Each tensor is [batch_size, H, W]. Batch size is 1 here.
# Access the result for the first object (index i) and the first item in batch (index 0)
out_obj_id: (out_mask_logits[i][0].detach().cpu() > 0.0).numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
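            # video_segments maps frame_idx -> {obj_id: boolean H x W mask} once the loop completes.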
# Optional: print progress
# print(f"Processed frame {out_frame_idx+1}/{len(session_state['all_frames'])}")
print("Propagation finished.")
except Exception as e:
print(f"Error during propagation: {e}")
return (
gr.update(value=None, visible=False), # Hide output video on error
session_state,
)
output_frames = []
# Iterate through all original frames to generate output video
total_frames = len(session_state["all_frames"])
for out_frame_idx in range(total_frames):
original_frame_rgb = session_state["all_frames"][out_frame_idx]
# Convert original frame to RGBA for compositing
transparent_background = Image.fromarray(original_frame_rgb).convert("RGBA")
# Check if we have a mask for this frame and object ID
if out_frame_idx in video_segments and OBJ_ID in video_segments[out_frame_idx]:
current_mask_numpy = video_segments[out_frame_idx][OBJ_ID]
# Get the mask image (RGBA)
mask_image_pil = show_mask(current_mask_numpy, obj_id=OBJ_ID)
# Composite the mask onto the frame
output_frame_img_rgba = Image.alpha_composite(transparent_background, mask_image_pil)
# Convert back to numpy RGB (moviepy needs RGB or RGBA)
output_frame_np = np.array(output_frame_img_rgba.convert("RGB"))
else:
# If no mask for this frame/object, just use the original frame (converted to RGB)
# Note: all_frames are already RGB numpy arrays, so just use them directly.
# print(f"Warning: No mask found for frame {out_frame_idx} and object {OBJ_ID}. Using original frame.")
output_frame_np = original_frame_rgb # Already RGB numpy array
output_frames.append(output_frame_np)
    # Free cached GPU memory after propagation.
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Define output path in a temporary directory
unique_id = datetime.now().strftime("%Y%m%d%H%M%S%f") # Use microseconds for more uniqueness
final_vid_filename = f"output_video_{unique_id}.mp4"
final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_filename)
print(f"Output video path: {final_vid_output_path}")
# Create a video clip from the image sequence
# Get original FPS from the stored video path
original_fps = get_video_fps(current_video_path)
fps = original_fps if original_fps is not None and original_fps > 0 else 30 # Default to 30 if detection fails or is zero
print(f"Creating output video with FPS: {fps}")
# Check if there are frames to process
if not output_frames:
print("No output frames generated.")
return (
gr.update(value=None, visible=False), # Hide output video
session_state,
)
# Create ImageSequenceClip from the list of numpy arrays
try:
clip = ImageSequenceClip(output_frames, fps=fps)
except Exception as e:
print(f"Error creating ImageSequenceClip: {e}")
return (
gr.update(value=None, visible=False), # Hide output video on error
session_state,
)
# Write the result to a file. Use 'libx264' codec for broad compatibility.
try:
print(f"Writing video file with codec='libx264', fps={fps}")
# Added basic moviepy writing parameters back, similar to original intent
clip.write_videofile(final_vid_output_path, codec="libx264", fps=fps)
print("Video writing complete.")
# Return the path and make the video player visible
return (
gr.update(value=final_vid_output_path, visible=True),
session_state,
)
except Exception as e:
print(f"Error writing video file: {e}")
# Clean up potentially created partial file
if os.path.exists(final_vid_output_path):
try:
os.remove(final_vid_output_path)
print(f"Removed partial video file: {final_vid_output_path}")
except Exception as clean_e:
print(f"Error removing partial file: {clean_e}")
# Return None if writing fails
return (
gr.update(value=None, visible=False),
session_state,
)
def update_output_video_visibility():
    """Returns a Gradio update that makes the output video player visible."""
return gr.update(visible=True)
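# In the Track button's click chain below, this visibility toggle runs first so the output
# player appears while propagate_to_all is still working.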
with gr.Blocks() as demo:
# Session state dictionary to hold video frames, points, labels, and predictor state
session_state = gr.State(
{
"first_frame": None, # numpy array (RGB)
"all_frames": None, # list of numpy arrays (RGB)
"input_points": [], # list of (x, y) tuples/lists
"input_labels": [], # list of 1s and 0s
"inference_state": None, # EdgeTAM predictor state object
"video_path": None, # Store the input video path
}
)
with gr.Column():
# Title
gr.Markdown(title)
with gr.Row():
with gr.Column():
# Instructions
gr.Markdown(description_p)
with gr.Accordion("Input Video", open=True) as video_in_drawer:
video_in = gr.Video(label="Input Video", format="mp4") # Will hold the video file path
with gr.Row():
point_type = gr.Radio(
label="point type",
choices=["include", "exclude"],
value="include",
scale=2,
interactive=True, # Make interactive
)
# Buttons are initially disabled until a video is loaded
propagate_btn = gr.Button("Track", scale=1, variant="primary", interactive=False)
clear_points_btn = gr.Button("Clear Points", scale=1, interactive=False)
reset_btn = gr.Button("Reset", scale=1, interactive=False)
# points_map is where users click to add points. Needs to be interactive.
# Shows the first frame with points drawn on it.
points_map = gr.Image(
label="Click on the First Frame to Add Points", # Clearer label
type="numpy",
        interactive=True,  # must be interactive so user clicks register as point prompts
height=400, # Set a fixed height for better UI
width="auto", # Let width adjust
show_share_button=False,
show_download_button=False,
)
with gr.Column():
gr.Markdown("# Try some of the examples below ⬇️")
gr.Examples(
examples=examples,
inputs=[video_in],
examples_per_page=8,
cache_examples=False, # Do not cache processed examples, as state is involved
)
# Add padding/space - removed extra lines as they take up a lot of space
# gr.Markdown("<br>")
# output_image shows the segmentation mask prediction on the *first* frame
output_image = gr.Image(
label="Segmentation Mask on First Frame", # Clearer label
type="numpy",
interactive=False, # Not interactive, just displays the mask
height=400, # Match height of points_map
width="auto", # Let width adjust
show_share_button=False,
show_download_button=False,
)
# output_video shows the final tracking result
output_video = gr.Video(visible=False, label="Tracking Result")
# --- Event Handlers ---
# When a new video file is uploaded via the file browser
    # The outputs also toggle the Track / Clear Points / Reset buttons once a video loads successfully.
video_in.upload(
fn=preprocess_video_in,
inputs=[video_in, session_state],
outputs=[
video_in_drawer, points_map, output_image, output_video,
propagate_btn, clear_points_btn, reset_btn, session_state,
],
queue=False, # Process immediately
)
# When an example video is selected (change event)
    # Same handler as upload, so selecting an example also initializes the predictor state and buttons.
video_in.change(
fn=preprocess_video_in,
inputs=[video_in, session_state],
outputs=[
video_in_drawer, points_map, output_image, output_video,
propagate_btn, clear_points_btn, reset_btn, session_state,
],
queue=False, # Process immediately
)
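    # Note: .change also fires when the video input is cleared, in which case
    # preprocess_video_in receives None and resets the session state and buttons.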
# Triggered when a user clicks on the points_map image
points_map.select(
fn=segment_with_points,
inputs=[
point_type, # "include" or "exclude" radio button value
session_state, # Pass session state
],
outputs=[
points_map, # Updated image with points drawn
output_image, # Updated image with first frame segmentation mask
session_state, # Updated session state (points/labels added)
],
queue=False, # Process clicks immediately
)
# Button to clear all selected points and reset the first frame mask
clear_points_btn.click(
fn=clear_points,
inputs=[session_state], # Pass session state
outputs=[
points_map, # points_map shows original first frame without points
output_image, # output_image cleared (or shows original first frame without mask)
output_video, # Hide output video
session_state, # Updated session state (points/labels cleared, inference state reset)
],
queue=False, # Process immediately
)
# Button to reset the entire demo state and UI
reset_btn.click(
fn=reset,
inputs=[session_state], # Pass session state
outputs=[
video_in, video_in_drawer, points_map, output_image, output_video,
propagate_btn, clear_points_btn, reset_btn, session_state,
],
queue=False, # Process immediately
)
# Button to start mask propagation through the video
propagate_btn.click(
fn=update_output_video_visibility, # First, make the output video player visible
inputs=[],
outputs=[output_video],
queue=False, # Process this UI update immediately
).then( # Then, run the propagation function
fn=propagate_to_all,
inputs=[
video_in, # Get the input video path (can also get from session_state["video_path"])
session_state, # Pass session state (contains frames, points, inference_state, video_path)
],
outputs=[
output_video, # Update output video player with result
session_state, # Update session state
],
# concurrency_limit from original code (may need adjustment based on your hardware/GPU)
concurrency_limit=10,
queue=False, # queue from original code
)
# Launch the Gradio demo
demo.queue() # Enable queuing
print("Gradio demo starting...")
demo.launch()
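# launch() typically blocks while the server is running, so the line below prints on shutdown.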
print("Gradio demo launched.")