"""Mask videos by tracking which first-frame pixels remain visible, using RAFT flow.

For every ``.mp4`` in a directory slice, this script writes a masked copy of
the video and a compressed ``.npz`` holding the per-frame boolean masks.
"""
import argparse
import os
import sys

import cv2
import imageio
import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm

# RAFT is used from a source checkout rather than an installed package.
sys.path.append("RAFT/core")
from raft import RAFT
from utils.utils import InputPadder

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def load_raft_model(ckpt_path):
    """Build a RAFT model, load weights from ``ckpt_path`` and return it in eval mode.

    The network is wrapped in ``torch.nn.DataParallel`` only for loading the
    state dict; the unwrapped ``.module`` is what gets moved to ``DEVICE`` and
    returned.
    """
    args = argparse.Namespace(
        small=False,
        mixed_precision=False,
        alternate_corr=False,
        dropout=0.0,
        max_depth=8,
        depth_network=False,
        depth_residual=False,
        depth_scale=1.0,
    )
    # Presumably the checkpoint was saved from a DataParallel model (keys
    # prefixed with "module.") -- confirm against the checkpoint.
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
    return model.module.to(DEVICE).eval()


def _in_frame(pos, width, height):
    """Return a boolean map of positions lying inside ``[0, width) x [0, height)``."""
    xs, ys = pos[..., 0], pos[..., 1]
    return (0 <= xs) & (xs < width) & (0 <= ys) & (ys < height)


def _rasterize_visible(pos, vis, width, height):
    """Splat the still-visible tracked positions into a dilated uint8 mask."""
    mask = np.zeros((height, width), np.uint8)
    ys, xs = np.where(vis)
    px = np.round(pos[ys, xs, 0]).astype(int)
    py = np.round(pos[ys, xs, 1]).astype(int)
    # Rounding can push a position that was just inside the frame onto the
    # border index, so the bounds are re-checked on the integer coordinates.
    inb = (0 <= px) & (px < width) & (0 <= py) & (py < height)
    mask[py[inb], px[inb]] = 1
    # The 2x2 dilation closes single-pixel holes left by the scatter.
    return cv2.dilate(mask, np.ones((2, 2), np.uint8))


def _mask_is_readable(mask_path):
    """Return True when ``mask_path`` loads and its ``"mask"`` entry can be read."""
    try:
        with np.load(mask_path) as archive:
            _ = archive["mask"]
    except Exception:
        # A damaged archive can fail in many ways (bad zip, missing key,
        # truncated data), so any ordinary exception counts as unreadable.
        # KeyboardInterrupt and SystemExit are deliberately not caught.
        return False
    return True


def run_masking(video_path, output_path, mask_path, raft):
    """Write a masked copy of one video plus the stack of per-frame masks.

    Every frame is resized to 720x480. First-frame pixel positions are
    advanced with frame-to-frame RAFT flow; a frame's mask marks where the
    still-in-bounds tracked positions land. When the mask covers less than 30%
    of the frame, positions are re-estimated from first-frame-to-current flow,
    and if that mask is still below one sixth of the frame area it is frozen
    and reused for all remaining frames.

    Args:
        video_path: Input video readable by ``cv2.VideoCapture``.
        output_path: Destination of the masked video (written with imageio).
        mask_path: Destination for ``np.savez_compressed``; the boolean masks
            are stored under the key ``"mask"`` (the first frame's mask is all
            True).
        raft: Flow model called as ``raft(img1, img2, iters=20, test_mode=True)``.

    Returns:
        None. If the video cannot be opened or its first frame cannot be read,
        a message is printed and nothing is written.
    """
    cap = cv2.VideoCapture(video_path)
    writer = None
    try:
        if not cap.isOpened():
            print(f"Failed to open video: {video_path}")
            return

        fps = cap.get(cv2.CAP_PROP_FPS)
        n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        ok, first = cap.read()
        if not ok:
            print(f"Failed to read first frame in {video_path}")
            return

        resize_to = (720, 480)  # cv2.resize takes (width, height)
        first = cv2.resize(first, resize_to)
        H, W, _ = first.shape
        area_thresh = (H * W) // 6

        # grid[y, x] = (x, y); pos[y, x] is the tracked location of that pixel.
        grid = np.stack(np.meshgrid(np.arange(W), np.arange(H)), -1).astype(np.float32)
        pos = grid.copy()
        vis = np.ones((H, W), dtype=bool)

        to_tensor_tf = transforms.ToTensor()

        def to_tensor(bgr):
            # NOTE(review): ToTensor scales uint8 input to [0, 1] and the frame
            # is passed in BGR order; the reference RAFT implementation expects
            # RGB in [0, 255] -- confirm against RAFT/core/raft.py.
            return to_tensor_tf(bgr).unsqueeze(0).to(DEVICE)

        def raft_flow(img1_bgr, img2_bgr):
            t1, t2 = to_tensor(img1_bgr), to_tensor(img2_bgr)
            padder = InputPadder(t1.shape)
            i1, i2 = padder.pad(t1, t2)
            with torch.no_grad():
                # Presumably returns (low-res flow, upsampled flow) -- confirm.
                _, flow = raft(i1, i2, iters=20, test_mode=True)
            return padder.unpad(flow)[0].permute(1, 2, 0).cpu().numpy()

        writer = imageio.get_writer(output_path, fps=int(fps))
        prev = first.copy()
        freeze_mask = False
        frozen_mask = None

        # The first frame is written unmasked (channels reversed for imageio).
        writer.append_data(first[:, :, ::-1])
        all_masks = [np.ones((H, W), dtype=bool)]

        for _ in range(1, n_frames):
            ok, cur = cap.read()
            if not ok:
                break
            cur = cv2.resize(cur, resize_to)

            if not freeze_mask:
                # NOTE(review): the flow is added per grid index, i.e. it is not
                # resampled at the tracked positions -- confirm this is intended.
                pos += raft_flow(prev, cur)
                vis &= _in_frame(pos, W, H)
                m = _rasterize_visible(pos, vis, W, H)

                if m.sum() / (H * W) < 0.3:
                    # Too little is left: restart tracking from a direct
                    # first-frame -> current-frame flow estimate.
                    pos = grid + raft_flow(first, cur)
                    vis = _in_frame(pos, W, H)
                    m = _rasterize_visible(pos, vis, W, H)

                    if m.sum() < area_thresh:
                        freeze_mask = True
                        frozen_mask = m.copy()
            else:
                m = frozen_mask

            effective_mask = m.astype(bool)
            all_masks.append(effective_mask)

            out = cur.copy()
            out[~effective_mask] = 0
            writer.append_data(out[:, :, ::-1])

            # Once frozen, no further flow is computed, so prev stays put.
            if not freeze_mask:
                prev = cur
    finally:
        # Close the writer and release the capture on every exit path.
        if writer is not None:
            writer.close()
        cap.release()

    np.savez_compressed(mask_path, mask=np.stack(all_masks, axis=0))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--mask_path", type=str, required=True)
    parser.add_argument("--raft_ckpt", type=str, required=True)
    parser.add_argument("--start_idx", type=int, required=True)
    parser.add_argument("--end_idx", type=int, required=True)
    parser.add_argument("--gpu_id", type=int, required=True)
    args = parser.parse_args()

    # NOTE(review): torch.cuda.is_available() already ran at import time above;
    # depending on the PyTorch version, CUDA_VISIBLE_DEVICES set this late may
    # be ignored. Confirm each worker really lands on its own GPU, or export
    # the variable before launching the process.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    os.makedirs(args.output_path, exist_ok=True)
    os.makedirs(args.mask_path, exist_ok=True)

    video_list = sorted(
        f for f in os.listdir(args.video_path) if f.endswith(".mp4")
    )
    selected_videos = video_list[args.start_idx : args.end_idx]

    print(f"[GPU {args.gpu_id}] Processing {len(selected_videos)} videos: {args.start_idx} to {args.end_idx}")

    model = load_raft_model(args.raft_ckpt)

    for fname in tqdm(selected_videos, desc="Batch Processing"):
        input_path = os.path.join(args.video_path, fname)
        # Swap only the trailing extension; str.replace would also rewrite any
        # ".mp4" occurring earlier in the file name.
        mask_name = fname[: -len(".mp4")] + ".npz"
        mask_path = os.path.join(args.mask_path, mask_name)
        output_path = os.path.join(args.output_path, fname)

        if os.path.exists(mask_path):
            if _mask_is_readable(mask_path):
                continue
            # Fall through to run_masking even if the output video exists: the
            # video is written before the mask, so checking for it here would
            # always cancel the regeneration announced below.
            print(f"⚠️ Mask corrupt or unreadable: {mask_path} - Regenerating")
        elif os.path.exists(output_path):
            # NOTE(review): an output video with no mask file is still treated
            # as done (unchanged behaviour); it may be an interrupted run.
            continue

        run_masking(input_path, output_path, mask_path, model)