import os

import cv2
import gradio as gr
import numpy as np
import requests
import torch
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry


def download_sam_model(
    model_url="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    checkpoint_path="sam_vit_h_4b8939.pth",
):
    """Ensure the SAM checkpoint exists locally and return its path.

    If ``checkpoint_path`` already exists nothing is downloaded. Otherwise the
    file at ``model_url`` is streamed to a temporary ``.part`` file and moved
    into place only after the whole body has been written, so an interrupted
    or failed download never leaves a truncated checkpoint that a later call
    would mistake for a complete one.

    Args:
        model_url: URL of the checkpoint to fetch.
        checkpoint_path: Local path where the checkpoint is stored.

    Returns:
        ``checkpoint_path``.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an HTTP error status.
        OSError: If the checkpoint cannot be written.
    """
    if os.path.exists(checkpoint_path):
        return checkpoint_path

    print("Downloading SAM model...")
    partial_path = checkpoint_path + ".part"
    try:
        # The context manager closes the streamed connection even on failure.
        with requests.get(model_url, stream=True, timeout=30) as response:
            # Without this an HTML error page would be saved as the checkpoint.
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                # 1 MiB chunks: the checkpoint is large, so tiny chunks only
                # add per-iteration overhead.
                for chunk in response.iter_content(chunk_size=1 << 20):
                    if chunk:
                        f.write(chunk)
        # Atomic rename: the final path only ever holds a complete file.
        os.replace(partial_path, checkpoint_path)
    finally:
        # After a successful os.replace the partial file no longer exists;
        # on any failure it is removed so the next call starts from scratch.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    print("Download complete!")
    return checkpoint_path
def _overlay_masks(frame, masks, color=(0, 255, 0), alpha=0.5):
    """Return a copy of ``frame`` with every mask region blended toward ``color``.

    ``masks`` is an iterable of dicts whose ``"segmentation"`` entry is a
    boolean array used to index ``frame``; ``color`` is in the same channel
    order as ``frame``.
    """
    annotated = frame.copy()
    tint = np.asarray(color, dtype=np.float32) * alpha
    for mask in masks:
        region = mask["segmentation"]
        blended = annotated[region].astype(np.float32) * (1.0 - alpha) + tint
        # Cast back explicitly instead of relying on implicit float->uint8
        # conversion during assignment.
        annotated[region] = blended.astype(annotated.dtype)
    return annotated


def process_video_sam(video_path):
    """Segment every frame of a video with SAM and write an annotated copy.

    Each frame is passed through ``SamAutomaticMaskGenerator`` and all
    returned masks are tinted green at 50% opacity. The result is written to
    ``output_video.mp4`` with the ``mp4v`` codec at the source resolution and
    frame rate.

    Args:
        video_path: Path of the input video file.

    Returns:
        The path of the annotated video (``"output_video.mp4"``).

    Raises:
        ValueError: If no path is given or the video cannot be opened.
    """
    if not video_path:
        raise ValueError("No input video was provided.")

    output_path = "output_video.mp4"
    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        # Validate the input before the expensive checkpoint download/load.
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")

        checkpoint_path = download_sam_model()

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model_type = "vit_h"  # must match the downloaded checkpoint

        sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
        sam.to(device=device)
        # SamPredictor only supports prompted prediction and has no
        # generate(); prompt-free segmentation is SamAutomaticMaskGenerator.
        mask_generator = SamAutomaticMaskGenerator(sam)

        # Some sources report 0 FPS, which the writer cannot use; fall back
        # to a conventional rate in that case.
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        out = cv2.VideoWriter(
            output_path,
            cv2.VideoWriter_fourcc(*'mp4v'),
            fps,
            (width, height),
        )

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes to BGR; SAM expects RGB input.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            masks = mask_generator.generate(rgb_frame)

            # Blend on the original BGR frame so the written colors are right.
            out.write(_overlay_masks(frame, masks))
    finally:
        # Release the capture and writer even if segmentation fails midway.
        cap.release()
        if out is not None:
            out.release()

    return output_path
# Gradio front end: one uploaded video in, one annotated video out, with
# process_video_sam doing the work for each submission.
iface = gr.Interface(
    title="Video Segmentation with SAM",
    description="Upload a video to segment objects using Segment Anything Model",
    fn=process_video_sam,
    inputs=gr.Video(label="Upload Video"),
    outputs=gr.Video(label="Segmented Video"),
)


# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()