import os
import copy
import tempfile
from datetime import datetime
import gc

import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import gradio as gr
from moviepy.editor import ImageSequenceClip

from sam2.build_sam import build_sam2_video_predictor

os.environ.pop("TORCH_CUDNN_SDPA_ENABLED", None)

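# EdgeTAM checkpoint and config paths, the example videos shown in the UI, and the id used for the single tracked object.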
sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"
examples = [[f"examples/{vid}"] for vid in ["01_dog.mp4", "02_cups.mp4", "03_blocks.mp4", "04_coffee.mp4", "05_default_juggle.mp4"]]
OBJ_ID = 0

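# Build the EdgeTAM video predictor on CPU; leave it as None when the model files are missing or loading fails.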
if os.path.exists(sam2_checkpoint) and os.path.exists(model_cfg):
    try:
        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
    except Exception as e:
        print("Error loading predictor:", e)
        predictor = None
else:
    print("Model files missing.")
    predictor = None


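# Read the source video's frame rate, defaulting to 30 fps when it cannot be determined.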
def get_fps(video_path):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return 30.0
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    # Some containers report 0 or NaN; fall back to 30 fps so downstream encoding never gets an invalid rate.
    return fps if fps and fps > 0 else 30.0


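# Clear all prompts, cached frames, and tracking state, and restore the UI to its initial layout.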
def reset(session):
    if session["inference_state"]:
        predictor.reset_state(session["inference_state"])
    session.update({"input_points": [], "input_labels": [], "first_frame": None, "all_frames": None, "inference_state": None})
    return None, gr.update(open=True), None, None, gr.update(value=None, visible=False), session


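# Drop the clicked points; if tracking has already started, the predictor state has to be reset as well.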
def clear_points(session):
    session["input_points"] = []
    session["input_labels"] = []
    if session["inference_state"] and session["inference_state"].get("tracking_has_started"):
        predictor.reset_state(session["inference_state"])
    return session["first_frame"], None, gr.update(value=None, visible=False), session


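# Decode the uploaded video into a strided set of preview frames (downscaled to 640 px wide) and initialize the predictor state on the original file.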
def preprocess_video(video_path, session):
    # The change event also fires when the video is cleared, so guard against a missing path.
    if video_path is None:
        return gr.update(open=True), None, None, gr.update(value=None, visible=False), session
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return gr.update(open=True), None, None, gr.update(value=None, visible=False), session

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    stride = max(1, total_frames // 300)
    frames, first_frame = [], None

    w, h = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    target_w = 640
    scale = target_w / w if w > target_w else 1.0

    frame_id = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if frame_id % stride == 0:
            if scale < 1.0:
                frame = cv2.resize(frame, (int(w * scale), int(h * scale)))
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if first_frame is None:
                first_frame = frame
            frames.append(frame)
        frame_id += 1
    cap.release()

    # Note: init_state indexes every frame of the original file, while `frames` keeps only every stride-th downscaled frame.
    session.update({"first_frame": first_frame, "all_frames": frames, "frame_stride": stride, "scale_factor": scale, "inference_state": predictor.init_state(video_path=video_path), "input_points": [], "input_labels": []})
    return gr.update(open=False), first_frame, None, gr.update(value=None, visible=False), session


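# Render a boolean mask as a semi-transparent RGBA overlay, colored by object id with the tab10 colormap.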
def show_mask(mask, obj_id=None):
    cmap = plt.get_cmap("tab10")
    color = np.array([*cmap(0 if obj_id is None else obj_id)[:3], 0.6])
    h, w = mask.shape
    mask_rgba = (mask.reshape(h, w, 1) * color.reshape(1, 1, -1) * 255).astype(np.uint8)
    proper_mask = np.zeros((h, w, 4), dtype=np.uint8)
    proper_mask[:, :, :min(mask_rgba.shape[2], 4)] = mask_rgba[:, :, :min(mask_rgba.shape[2], 4)]
    return Image.fromarray(proper_mask, "RGBA")


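# Handle a click on the first frame: store the prompt point, draw the markers, and preview the resulting mask.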
def segment_with_points(ptype, session, evt: gr.SelectData):
    session["input_points"].append(evt.index)
    session["input_labels"].append(1 if ptype == "include" else 0)
    first = session["first_frame"]
    h, w = first.shape[:2]

    # Draw the clicked points (green = include, red = exclude) on a transparent layer.
    layer = np.zeros((h, w, 4), dtype=np.uint8)
    for idx, pt in enumerate(session["input_points"]):
        color = (0, 255, 0, 255) if session["input_labels"][idx] == 1 else (255, 0, 0, 255)
        cv2.circle(layer, tuple(pt), int(min(w, h) * 0.01), color, -1)

    overlay = Image.alpha_composite(Image.fromarray(first).convert("RGBA"), Image.fromarray(layer, "RGBA"))

    try:
        _, _, logits = predictor.add_new_points(session["inference_state"], 0, OBJ_ID, np.array(session["input_points"]), np.array(session["input_labels"]))
        # Squeeze away the leading singleton dimension of the per-object logits before resizing to the preview frame.
        mask = (logits[0] > 0.0).cpu().numpy().squeeze()
        mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
        mask_img = show_mask(mask)
        return overlay, Image.alpha_composite(Image.fromarray(first).convert("RGBA"), mask_img), session
    except Exception as e:
        print("Segmentation error:", e)
        return overlay, overlay, session


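# Propagate the prompts through the video, overlay the predicted masks, and encode the result as an MP4.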
def propagate(video_in, session, progress=gr.Progress()):
    if not session["input_points"] or not session["inference_state"]:
        return None, session

    # Collect a boolean mask per object for every frame the predictor propagates over.
    masks = {}
    for i, (idxs, obj_ids, logits) in enumerate(predictor.propagate_in_video(session["inference_state"])):
        try:
            masks[idxs] = {oid: (logits[j] > 0.0).cpu().numpy().squeeze() for j, oid in enumerate(obj_ids)}
            progress(min(i / 300, 1.0), desc=f"Tracking frame {idxs}")
        except Exception:
            continue

    # The predictor indexes original frames, while all_frames keeps only every frame_stride-th one,
    # so preview frame k corresponds to mask key k * frame_stride. Render at most ~50 output frames.
    frame_stride = session.get("frame_stride", 1)
    frames_out = []
    vis_stride = max(1, len(session["all_frames"]) // 50)
    for k in range(0, len(session["all_frames"]), vis_stride):
        mask_idx = k * frame_stride
        if mask_idx not in masks or OBJ_ID not in masks[mask_idx]:
            continue
        try:
            frame = session["all_frames"][k]
            mask = masks[mask_idx][OBJ_ID]
            h, w = frame.shape[:2]
            mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
            output = Image.alpha_composite(Image.fromarray(frame).convert("RGBA"), show_mask(mask))
            frames_out.append(np.array(output.convert("RGB")))  # encode as RGB for the H.264 writer
        except Exception:
            continue

    if not frames_out:
        return gr.update(value=None, visible=False), session

    out_path = os.path.join(tempfile.gettempdir(), f"output_video_{datetime.now().strftime('%Y%m%d%H%M%S')}.mp4")
    fps = min(15, get_fps(video_in))
    ImageSequenceClip(frames_out, fps=fps).write_videofile(out_path, codec="libx264", bitrate="800k", threads=2, logger=None)
    gc.collect()
    return gr.update(value=out_path, visible=True), session


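# Gradio UI: per-session state, the input/prompt/output components, and their event wiring.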
with gr.Blocks() as demo:
    state = gr.State({"first_frame": None, "all_frames": None, "input_points": [], "input_labels": [], "inference_state": None, "frame_stride": 1, "scale_factor": 1.0, "original_dimensions": None})

    gr.Markdown("<center><strong><font size='8'>EdgeTAM CPU</font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a></center>")

    with gr.Row():
        with gr.Column():
            gr.Markdown("""<ol><li>Upload a video or use an example</li><li>Select 'include' or 'exclude' and click points</li><li>Click 'Track' to segment and track</li></ol>""")
            drawer = gr.Accordion("Input Video", open=True)
            with drawer:
                video_in = gr.Video(label="Input Video", format="mp4")
            ptype = gr.Radio(label="Point Type", choices=["include", "exclude"], value="include")
            track_btn = gr.Button("Track", variant="primary")
            clear_btn = gr.Button("Clear Points")
            reset_btn = gr.Button("Reset")
            points_map = gr.Image(label="Frame with Points", type="numpy", interactive=False)
        with gr.Column():
            gr.Markdown("# Try some examples ⬇️")
            gr.Examples(examples, inputs=[video_in], examples_per_page=5)
            output_img = gr.Image(label="Reference Mask")
            output_vid = gr.Video(visible=False)

    video_in.upload(preprocess_video, [video_in, state], [drawer, points_map, output_img, output_vid, state])
    video_in.change(preprocess_video, [video_in, state], [drawer, points_map, output_img, output_vid, state])
    points_map.select(segment_with_points, [ptype, state], [points_map, output_img, state])
    clear_btn.click(clear_points, state, [points_map, output_img, output_vid, state])
    reset_btn.click(reset, state, [video_in, drawer, points_map, output_img, output_vid, state])
    track_btn.click(fn=propagate, inputs=[video_in, state], outputs=[output_vid, state])

if __name__ == "__main__":
    demo.queue()
    demo.launch()