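"""Gradio app: video segmentation with the Segment Anything Model (SAM).

Downloads the ViT-H SAM checkpoint on first run, runs automatic mask
generation on every frame of an uploaded video, and returns a video with a
green overlay on the detected masks.
"""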
import os
import gradio as gr
import cv2
import numpy as np
import torch
import requests
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
def download_sam_model():
    """Download the SAM ViT-H checkpoint if it is not already present."""
    model_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    checkpoint_path = "sam_vit_h_4b8939.pth"

    if not os.path.exists(checkpoint_path):
        print("Downloading SAM model...")
        response = requests.get(model_url, stream=True)
        response.raise_for_status()  # fail early instead of saving an error page
        with open(checkpoint_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
        print("Download complete!")

    return checkpoint_path
def process_video_sam(video_path):
    # Download model if needed
    checkpoint_path = download_sam_model()

    # Initialize SAM. SamPredictor requires point/box prompts per image, so
    # for unprompted whole-frame segmentation we use SamAutomaticMaskGenerator.
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    MODEL_TYPE = "vit_h"
    sam = sam_model_registry[MODEL_TYPE](checkpoint=checkpoint_path)
    sam.to(device=DEVICE)
    mask_generator = SamAutomaticMaskGenerator(sam)
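    # Optional: the generator's speed/quality trade-off can be tuned via
    # keyword arguments such as points_per_side and pred_iou_thresh; the
    # values below are illustrative, not tuned for this app:
    #   mask_generator = SamAutomaticMaskGenerator(sam, points_per_side=16,
    #                                              pred_iou_thresh=0.9)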
    # Process video
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    output_path = "output_video.mp4"
    out = cv2.VideoWriter(output_path,
                          cv2.VideoWriter_fourcc(*'mp4v'),
                          fps,
                          (width, height))
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # SAM expects RGB input; OpenCV decodes frames as BGR
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        masks = mask_generator.generate(frame_rgb)

        # Each mask is a dict; its "segmentation" entry is a boolean array.
        # Blend a green overlay into the masked pixels.
        annotated_frame = frame.copy()
        for mask in masks:
            m = mask["segmentation"]
            overlay = annotated_frame[m] * 0.5 + np.array([0, 255, 0]) * 0.5
            annotated_frame[m] = overlay.astype(np.uint8)

        out.write(annotated_frame)

    cap.release()
    out.release()
    return output_path
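# Minimal smoke test outside Gradio, assuming a local file "input.mp4"
# exists (the filename is only an example):
#   result = process_video_sam("input.mp4")
#   print(f"Segmented video written to {result}")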
iface = gr.Interface(
    fn=process_video_sam,
    inputs=gr.Video(label="Upload Video"),
    outputs=gr.Video(label="Segmented Video"),
    title="Video Segmentation with SAM",
    description="Upload a video to segment objects using Segment Anything Model"
)
if __name__ == "__main__":
    iface.launch()