Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import copy | |
| import os | |
| from datetime import datetime | |
| import gradio as gr | |
| os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1" | |
| import tempfile | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import spaces | |
| import torch | |
| from moviepy.editor import ImageSequenceClip | |
| from PIL import Image | |
| from sam2.build_sam import build_sam2_video_predictor | |
# Description
# HTML header rendered through gr.Markdown. Each <font> element is explicitly
# closed so the size='8' styling applies to the "EdgeTAM" text only.
title = "<center><strong><font size='8'>EdgeTAM</font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a> </center>"
# Step-by-step usage instructions shown next to the input video.
description_p = """# Instructions
<ol>
<li> Upload one video or click one example video</li>
<li> Click 'include' point type, select the object to segment and track</li>
<li> Click 'exclude' point type (optional), select the area you want to avoid segmenting and tracking</li>
<li> Click the 'Track' button to obtain the masked video </li>
</ol>
"""
# examples
# File stems of the bundled sample clips, in display order. Each one is
# expanded to a single-element row because gr.Examples feeds one input
# component (the video).
_EXAMPLE_STEMS = (
    "01_dog",
    "02_cups",
    "03_blocks",
    "04_coffee",
    "05_default_juggle",
    "01_breakdancer",
    "02_hummingbird",
    "03_skateboarder",
    "04_octopus",
    "05_landing_dog_soccer",
    "06_pingpong",
    "07_snowboarder",
    "08_driving",
    "09_birdcartoon",
    "10_cloth_magic",
    "11_polevault",
    "12_hideandseek",
    "13_butterfly",
    "14_social_dog_training",
    "15_cricket",
    "16_robotarm",
    "17_childrendancing",
    "18_threedogs",
    "19_cyclist",
    "20_doughkneading",
    "21_biker",
    "22_dogskateboarder",
    "23_racecar",
    "24_clownfish",
)
examples = [[f"examples/trimmed/{stem}.mp4"] for stem in _EXAMPLE_STEMS]
# Id of the single tracked object: passed as `obj_id` when adding point
# prompts and used to look up that object's mask in the per-frame results.
OBJ_ID = 0
# Checkpoint path and config name handed to the video-predictor builder.
sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"
# Built once at import time on CPU; the handlers in this file move it between
# devices with predictor.to("cpu") / predictor.to("cuda") as needed.
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
def get_video_fps(video_path):
    """Return the frame rate OpenCV reports for ``video_path``.

    Args:
        video_path: Path of the video file to inspect.

    Returns:
        The value of ``cv2.CAP_PROP_FPS`` for the video, or ``None`` when the
        file cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print("Error: Could not open video.")
            return None
        return cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Always free the capture handle; it was previously left open on
        # every call.
        cap.release()
def reset_state(inference_state):
    """Empty all prompt, output and object-bookkeeping containers in place.

    Args:
        inference_state: Mutable mapping holding the predictor's tracking
            state (the keys touched are listed in the body).

    Returns:
        The same ``inference_state`` object, after clearing.
    """
    storage_keys = ("cond_frame_outputs", "non_cond_frame_outputs")

    # Per-object prompt inputs: clear each object's container first.
    for inputs_key in ("point_inputs_per_obj", "mask_inputs_per_obj"):
        for per_obj_inputs in inference_state[inputs_key].values():
            per_obj_inputs.clear()

    # Per-object outputs (final and temporary).
    for outputs_key in ("output_dict_per_obj", "temp_output_dict_per_obj"):
        for per_obj_outputs in inference_state[outputs_key].values():
            for storage_key in storage_keys:
                per_obj_outputs[storage_key].clear()

    # Outputs and frame-index bookkeeping kept at the top level of the state.
    for top_level_key in ("output_dict", "consolidated_frame_inds"):
        for storage_key in storage_keys:
            inference_state[top_level_key][storage_key].clear()

    inference_state["tracking_has_started"] = False

    # Finally drop the tracked-frame record, the object id mappings and the
    # (now empty) per-object containers themselves.
    for container_key in (
        "frames_already_tracked",
        "obj_id_to_idx",
        "obj_idx_to_id",
        "obj_ids",
        "point_inputs_per_obj",
        "mask_inputs_per_obj",
        "output_dict_per_obj",
        "temp_output_dict_per_obj",
    ):
        inference_state[container_key].clear()

    return inference_state
def reset(
    first_frame,
    all_frames,
    input_points,
    input_labels,
    inference_state,
):
    """Return the values that put the whole UI back into its initial state.

    The incoming session values are ignored; every state slot is replaced
    with a fresh empty value.

    Returns:
        A 10-tuple matching the outputs wired to the Reset button: input
        video, input-video accordion, point map, reference mask, output
        video, then the five session-state values.
    """
    reopen_drawer = gr.update(open=True)
    hide_output_video = gr.update(value=None, visible=False)
    return (
        None,  # video_in
        reopen_drawer,  # video_in_drawer
        None,  # points_map
        None,  # output_image
        hide_output_video,  # output_video
        None,  # first_frame
        None,  # all_frames
        [],  # input_points
        [],  # input_labels
        None,  # inference_state
    )
def clear_points(
    first_frame,
    all_frames,
    input_points,
    input_labels,
    inference_state,
):
    """Discard all clicked points and show the untouched first frame again.

    When tracking has already started on an existing inference state, that
    state is wiped with ``reset_state`` as well.

    Returns:
        An 8-tuple matching the outputs wired to the Clear Points button:
        point map, reference mask, output video, then the five session-state
        values (with empty point and label lists).
    """
    if inference_state:
        if inference_state["tracking_has_started"]:
            inference_state = reset_state(inference_state)

    hide_output_video = gr.update(value=None, visible=False)
    return (
        first_frame,  # points_map: first frame without markers
        None,  # output_image
        hide_output_video,  # output_video
        first_frame,
        all_frames,
        [],  # input_points
        [],  # input_labels
        inference_state,
    )
def preprocess_video_in(
    video_path,
    first_frame,
    all_frames,
    input_points,
    input_labels,
    inference_state,
):
    """Load a newly selected video and prepare a fresh tracking session.

    All frames are decoded to RGB arrays, the first frame is kept for point
    selection, and a new predictor inference state is initialised on CPU.

    Args:
        video_path: Path of the selected video, or ``None`` when cleared.
        first_frame, all_frames, input_points, input_labels, inference_state:
            Current session values; returned unchanged when there is no video
            or it cannot be opened.

    Returns:
        Nine values matching the wired outputs: accordion update, point map,
        reference mask, output-video update, then the five session-state
        values.
    """

    def _keep_current_session():
        # Nothing usable to load: keep the drawer open and leave state as is.
        return (
            gr.update(open=True),  # video_in_drawer
            None,  # points_map
            None,  # output_image
            gr.update(value=None, visible=False),  # output_video
            first_frame,
            all_frames,
            input_points,
            input_labels,
            inference_state,
        )

    if video_path is None:
        return _keep_current_session()

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return _keep_current_session()

    # Decode every frame, converting OpenCV's BGR order to RGB.
    decoded_frames = []
    has_frame, bgr_frame = cap.read()
    while has_frame:
        rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
        decoded_frames.append(np.array(rgb_frame))
        has_frame, bgr_frame = cap.read()
    cap.release()

    # Keep an independent copy of the first frame (None for an empty video).
    leading_frame = decoded_frames[0] if decoded_frames else None
    first_frame = copy.deepcopy(leading_frame)
    all_frames = decoded_frames
    input_points = []
    input_labels = []

    predictor.to("cpu")
    inference_state = predictor.init_state(
        offload_video_to_cpu=True,
        offload_state_to_cpu=True,
        video_path=video_path,
    )

    return [
        gr.update(open=False),  # video_in_drawer
        first_frame,  # points_map
        None,  # output_image
        gr.update(value=None, visible=False),  # output_video
        first_frame,
        all_frames,
        input_points,
        input_labels,
        inference_state,
    ]
def segment_with_points(
    point_type,
    first_frame,
    all_frames,
    input_points,
    input_labels,
    inference_state,
    evt: gr.SelectData,
):
    """Record a clicked point and preview the resulting mask on frame 0.

    Args:
        point_type: ``"include"`` (label 1) or ``"exclude"`` (label 0).
        first_frame: RGB array of the first video frame.
        all_frames: Decoded frames; passed through unchanged.
        input_points: Points clicked so far; ``evt.index`` is appended.
        input_labels: Labels parallel to ``input_points``.
        inference_state: Predictor state for the current video.
        evt: Selection event whose ``index`` is used as the click position.

    Returns:
        A 7-tuple: first frame with point markers, first frame with the mask
        overlay, then the five session-state values.
    """
    predictor.to("cpu")
    inference_state["device"] = predictor.device

    input_points.append(evt.index)
    print(f"TRACKING INPUT POINT: {input_points}")

    label_for_type = {"include": 1, "exclude": 0}
    if point_type in label_for_type:
        input_labels.append(label_for_type[point_type])
    print(f"TRACKING INPUT LABEL: {input_labels}")

    # First frame as RGBA so the marker and mask layers can be composited.
    base_rgba = Image.fromarray(first_frame).convert("RGBA")
    width, height = base_rgba.size

    # Marker radius scales with the smaller image dimension.
    fraction = 0.01
    radius = int(fraction * min(width, height))

    # Draw one filled circle per point on a fully transparent layer:
    # green for include points, red for exclude points.
    include_rgba = (0, 255, 0, 255)
    exclude_rgba = (255, 0, 0, 255)
    marker_layer = np.zeros((height, width, 4), dtype=np.uint8)
    for position, point in enumerate(input_points):
        marker_rgba = include_rgba if input_labels[position] == 1 else exclude_rgba
        cv2.circle(marker_layer, point, radius, marker_rgba, -1)

    marker_image = Image.fromarray(marker_layer, "RGBA")
    selected_point_map = Image.alpha_composite(base_rgba, marker_image)

    # Send every accumulated click to the predictor as a prompt on frame 0;
    # label 1 is a positive click and label 0 a negative click.
    points = np.array(input_points, dtype=np.float32)
    labels = np.array(input_labels, dtype=np.int32)
    _, _, out_mask_logits = predictor.add_new_points(
        inference_state=inference_state,
        frame_idx=0,
        obj_id=OBJ_ID,
        points=points,
        labels=labels,
    )

    # Threshold the first returned mask's logits at 0 and overlay it.
    binary_mask = (out_mask_logits[0] > 0.0).cpu().numpy()
    mask_image = show_mask(binary_mask)
    first_frame_output = Image.alpha_composite(base_rgba, mask_image)

    torch.cuda.empty_cache()
    return (
        selected_point_map,
        first_frame_output,
        first_frame,
        all_frames,
        input_points,
        input_labels,
        inference_state,
    )
def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
    """Turn a binary mask into a semi-transparent RGBA colour overlay.

    Args:
        mask: Array whose last two dimensions are height and width; it is
            reshaped to ``(h, w, 1)``.
        obj_id: Index into matplotlib's ``tab10`` colormap (0 when ``None``).
        random_color: Use a random RGB colour instead of the colormap.
        convert_to_image: Return a PIL RGBA image instead of a uint8 array.

    Returns:
        A PIL ``RGBA`` image, or an ``(h, w, 4)`` uint8 array when
        ``convert_to_image`` is false. Alpha is 0.6 (153) inside the mask.
    """
    alpha = 0.6
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)
    else:
        palette = plt.get_cmap("tab10")
        palette_idx = obj_id if obj_id is not None else 0
        rgba = np.array([*palette(palette_idx)[:3], alpha])

    height, width = mask.shape[-2:]
    # Broadcast the colour over the mask: pixels outside the mask become 0.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    overlay = (overlay * 255).astype(np.uint8)

    if not convert_to_image:
        return overlay
    return Image.fromarray(overlay, "RGBA")
def propagate_to_all(
    video_in,
    all_frames,
    input_points,
    inference_state,
):
    """Track the prompted object through the video and render the result.

    Args:
        video_in: Path of the input video (used to read its frame rate).
        all_frames: Decoded RGB frames of the video, indexed by frame number.
        input_points: Points clicked by the user; tracking needs at least one.
        inference_state: Predictor state holding the prompts for this video.

    Returns:
        ``None`` when there is nothing to track (no points, no video, or no
        inference state); otherwise a ``gr.update`` whose value is the path
        of the rendered MP4 with the mask overlaid on every frame.
    """
    # Validate first: the state is mutated and CUDA is queried right below,
    # so clicking "Track" with nothing loaded must bail out here instead of
    # raising on a None inference state.
    if not input_points or video_in is None or inference_state is None:
        return None

    # Enable TF32 matmuls on GPUs with compute capability major version >= 8.
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        predictor.to("cuda")
        inference_state["device"] = predictor.device

        # Run propagation throughout the video and collect, per frame, a
        # binary mask for every object id.
        video_segments = {}
        print("starting propagate_in_video")
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
            inference_state
        ):
            video_segments[out_frame_idx] = {
                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }

        # Overlay the tracked object's mask on every `vis_frame_stride`-th frame.
        vis_frame_stride = 1
        output_frames = []
        for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
            transparent_background = Image.fromarray(
                all_frames[out_frame_idx]
            ).convert("RGBA")
            out_mask = video_segments[out_frame_idx][OBJ_ID]
            mask_image = show_mask(out_mask)
            output_frame = Image.alpha_composite(transparent_background, mask_image)
            output_frames.append(np.array(output_frame))

        torch.cuda.empty_cache()

    # Re-encode the overlaid frames at the source video's frame rate.
    fps = get_video_fps(video_in)
    clip = ImageSequenceClip(output_frames, fps=fps)

    # A second-resolution timestamp alone can collide when several requests
    # finish in the same second, so let tempfile reserve a unique path in the
    # temp directory. The file is kept (delete=False) for Gradio to serve.
    unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
    with tempfile.NamedTemporaryFile(
        prefix=f"output_video_{unique_id}_", suffix=".mp4", delete=False
    ) as reserved_file:
        final_vid_output_path = reserved_file.name

    clip.write_videofile(final_vid_output_path, codec="libx264")
    return gr.update(value=final_vid_output_path)
def update_ui():
    """Return a component update that makes its target visible."""
    make_visible = gr.update(visible=True)
    return make_visible
# NOTE(review): the layout nesting below is reconstructed (row with two
# columns under the title) — confirm against the deployed app.
with gr.Blocks() as demo:
    # Per-session values passed into and out of every handler.
    first_frame = gr.State()
    all_frames = gr.State()
    input_points = gr.State([])
    input_labels = gr.State([])
    inference_state = gr.State()

    with gr.Column():
        # Title
        gr.Markdown(title)
        with gr.Row():
            with gr.Column():
                # Instructions
                gr.Markdown(description_p)

                with gr.Accordion("Input Video", open=True) as video_in_drawer:
                    video_in = gr.Video(label="Input Video", format="mp4")

                with gr.Row():
                    point_type = gr.Radio(
                        label="point type",
                        choices=["include", "exclude"],
                        value="include",
                        scale=2,
                    )
                    propagate_btn = gr.Button("Track", scale=1, variant="primary")
                    clear_points_btn = gr.Button("Clear Points", scale=1)
                    reset_btn = gr.Button("Reset", scale=1)

                points_map = gr.Image(
                    label="Frame with Point Prompt", type="numpy", interactive=False
                )

            with gr.Column():
                gr.Markdown("# Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[video_in],
                    examples_per_page=8,
                )
                # Three blank markdown blocks act as vertical spacing.
                for _ in range(3):
                    gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
                output_image = gr.Image(label="Reference Mask")
                output_video = gr.Video(visible=False)

    # The five state components always travel together, in this order.
    session_state = [
        first_frame,
        all_frames,
        input_points,
        input_labels,
        inference_state,
    ]

    # A new video (uploaded, or changed e.g. by picking an example) starts a
    # fresh session; both events share the same handler and wiring.
    for register_video_event in (video_in.upload, video_in.change):
        register_video_event(
            fn=preprocess_video_in,
            inputs=[video_in, *session_state],
            outputs=[
                video_in_drawer,  # Accordion to hide uploaded video player
                points_map,  # Image component where we add new tracking points
                output_image,
                output_video,
                *session_state,
            ],
            queue=False,
        )

    # Clicking on the frame adds a point prompt.
    points_map.select(
        fn=segment_with_points,
        inputs=[
            point_type,  # "include" or "exclude"
            *session_state,
        ],
        outputs=[
            points_map,  # updated image with points
            output_image,
            *session_state,
        ],
        queue=False,
    )

    # Remove every point clicked so far.
    clear_points_btn.click(
        fn=clear_points,
        inputs=[*session_state],
        outputs=[points_map, output_image, output_video, *session_state],
        queue=False,
    )

    # Return the whole UI to its initial state.
    reset_btn.click(
        fn=reset,
        inputs=[*session_state],
        outputs=[
            video_in,
            video_in_drawer,
            points_map,
            output_image,
            output_video,
            *session_state,
        ],
        queue=False,
    )

    # "Track": reveal the output video player first, then run propagation.
    reveal_output = propagate_btn.click(
        fn=update_ui,
        inputs=[],
        outputs=output_video,
        queue=False,
    )
    reveal_output.then(
        fn=propagate_to_all,
        inputs=[video_in, all_frames, input_points, inference_state],
        outputs=[output_video],
        concurrency_limit=10,
        queue=False,
    )

# demo.queue()
demo.launch()