|
import os |
|
import gradio as gr |
|
import cv2 |
|
import numpy as np |
|
import torch |
|
import requests |
|
from segment_anything import sam_model_registry, SamPredictor |
|
|
|
def download_sam_model():
    """Return the local path of the SAM ViT-H checkpoint, downloading it if absent.

    The checkpoint is looked up by a fixed relative filename in the current
    working directory. When it is missing, it is streamed from the official
    URL into a temporary ``.part`` file and renamed into place only after the
    transfer finished, so an interrupted or failed download can never be
    mistaken for a valid cached checkpoint on the next call.

    Returns:
        The relative path of the checkpoint file.

    Raises:
        requests.RequestException: If the server cannot be reached, times
            out, or answers with an HTTP error status.
        OSError: If the checkpoint cannot be written to disk.
    """
    model_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    checkpoint_path = "sam_vit_h_4b8939.pth"

    if os.path.exists(checkpoint_path):
        return checkpoint_path

    print("Downloading SAM model...")
    partial_path = checkpoint_path + ".part"
    try:
        # (connect, read) timeouts keep a stalled connection from hanging forever.
        with requests.get(model_url, stream=True, timeout=(10, 60)) as response:
            # Without this, an HTML error page would be saved as the checkpoint.
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
        # Publish the file only once it is complete.
        os.replace(partial_path, checkpoint_path)
    finally:
        # After a successful replace the partial file is gone; on any failure
        # remove the leftovers so the next call starts from scratch.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Download complete!")

    return checkpoint_path
|
|
|
def process_video_sam(video_path, progress=gr.Progress()):
    """Segment every frame of a video with SAM and write an annotated copy.

    Each frame is prompted with a single foreground point at the frame
    centre; every mask SAM returns for that prompt is blended 50/50 with
    green onto the frame, and the result is written to ``output_video.mp4``
    in the current working directory.

    Args:
        video_path: Path of the input video file, as supplied by the
            ``gr.Video`` input component.
        progress: Gradio progress tracker, updated once per processed frame.

    Returns:
        The path of the annotated output video.

    Raises:
        gr.Error: If no video was supplied or it cannot be opened.
    """
    # Validate the input before paying for the checkpoint download/model load.
    if not video_path:
        raise gr.Error("Please upload a video before processing.")

    output_path = "output_video.mp4"
    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        if not cap.isOpened():
            raise gr.Error("Could not open the uploaded video.")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        if not fps > 0:
            # Some containers report 0/NaN fps, which VideoWriter cannot use.
            fps = 30.0

        checkpoint_path = download_sam_model()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        sam = sam_model_registry["vit_h"](checkpoint=checkpoint_path)
        sam.to(device=device)
        predictor = SamPredictor(sam)

        out = cv2.VideoWriter(
            output_path,
            cv2.VideoWriter_fourcc(*'mp4v'),
            fps,
            (width, height),
        )

        # The prompt is identical for every frame, so build it once.
        input_point = np.array([[width // 2, height // 2]])
        input_label = np.array([1])
        overlay_color = np.array([0, 255, 0])

        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes to BGR, while SamPredictor.set_image expects RGB
            # unless told otherwise.
            predictor.set_image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

            masks, _scores, _logits = predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=True,
            )

            annotated_frame = frame.copy()
            for mask in masks:
                blended = annotated_frame[mask] * 0.5 + overlay_color * 0.5
                annotated_frame[mask] = blended.astype(np.uint8)

            out.write(annotated_frame)

            frame_count += 1
            if total_frames > 0:
                # The reported frame count can be missing or too small, so
                # guard the division and clamp the fraction to 1.0.
                progress(
                    min(frame_count / total_frames, 1.0),
                    desc="Processing video...",
                )
    finally:
        # Release the native handles even when a frame fails mid-way.
        cap.release()
        if out is not None:
            out.release()

    return output_path
|
|
|
# Gradio UI: one column with the upload widget and the trigger button, and a
# second column that shows the segmented result.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Video Segmentation with SAM")
    gr.Markdown("Upload a video to segment objects using Segment Anything Model")

    with gr.Row():
        with gr.Column():
            input_video = gr.Video(label="Input Video")
            process_btn = gr.Button("Process Video", variant="primary")

        with gr.Column():
            output_video = gr.Video(label="Segmented Video")

    # Also exposes the handler through the API under the name "segment_video".
    process_btn.click(
        fn=process_video_sam,
        inputs=input_video,
        outputs=output_video,
        api_name="segment_video",
    )


if __name__ == "__main__":
    # The handler reports progress through gr.Progress, which needs the
    # request queue; enable it explicitly for Gradio versions where it is
    # not on by default.
    demo.queue()
    demo.launch()