DHEIVER's picture
Update app.py
ded1ee1 verified
import os
import gradio as gr
import cv2
import numpy as np
import torch
import requests
from segment_anything import sam_model_registry, SamPredictor
def download_sam_model():
    """Ensure the SAM ViT-H checkpoint is present locally and return its path.

    The checkpoint is fetched from the public Segment Anything URL only when
    ``sam_vit_h_4b8939.pth`` does not already exist in the working directory.

    Returns:
        The relative path of the checkpoint file.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an HTTP error status.
        OSError: If the checkpoint cannot be written to disk.
    """
    model_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    checkpoint_path = "sam_vit_h_4b8939.pth"
    if os.path.exists(checkpoint_path):
        return checkpoint_path

    print("Downloading SAM model...")
    # Stream into a side file and rename only on success, so an interrupted
    # download can never be mistaken for a complete checkpoint on the next call.
    partial_path = checkpoint_path + ".part"
    try:
        with requests.get(model_url, stream=True, timeout=60) as response:
            # Without this an HTML error page would be saved as the checkpoint.
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                # 1 MiB chunks: the file is large, tiny chunks only add overhead.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
        os.replace(partial_path, checkpoint_path)
    except BaseException:
        # Also covers Ctrl-C: never leave a half-written file behind.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    print("Download complete!")
    return checkpoint_path
def process_video_sam(video_path, progress=gr.Progress()):
    """Overlay SAM segmentation masks on every frame of a video.

    Each frame is prompted with a single foreground point at the frame
    centre. Every mask returned for that prompt is blended 50/50 with green
    onto a copy of the frame (pixels covered by several masks are blended
    once per mask), and the result is written to ``output_video.mp4`` in the
    working directory.

    Args:
        video_path: Path of the input video file.
        progress: Gradio progress tracker, updated after each written frame.

    Returns:
        The path of the annotated output video.

    Raises:
        gr.Error: If no video was supplied or it cannot be opened.
    """
    # Validate the input before the expensive checkpoint download/model load.
    if not video_path:
        raise gr.Error("Please upload a video before processing.")

    output_path = "output_video.mp4"
    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        if not cap.isOpened():
            raise gr.Error("Could not open the uploaded video.")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # The property can come back as 0 (or NaN) when the backend does
            # not know it; fall back to a sane rate for the writer.
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        checkpoint_path = download_sam_model()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        sam = sam_model_registry["vit_h"](checkpoint=checkpoint_path)
        sam.to(device=device)
        predictor = SamPredictor(sam)

        out = cv2.VideoWriter(
            output_path,
            cv2.VideoWriter_fourcc(*"mp4v"),
            fps,
            (width, height),
        )

        # Automatic prompt: one foreground point at the frame centre. It does
        # not change between frames, so build it once.
        input_point = np.array([[width // 2, height // 2]])
        input_label = np.array([1])
        green = np.array([0, 255, 0])

        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes to BGR; set_image assumes RGB unless told otherwise.
            predictor.set_image(frame, image_format="BGR")
            masks, _scores, _logits = predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=True,
            )

            annotated_frame = frame.copy()
            for mask in masks:
                blended = annotated_frame[mask] * 0.5 + green * 0.5
                # The blend is float; cast back explicitly to the frame dtype.
                annotated_frame[mask] = blended.astype(annotated_frame.dtype)
            out.write(annotated_frame)

            frame_count += 1
            if total_frames > 0:
                # The reported frame count is only an estimate; clamp to 1.0.
                progress(
                    min(frame_count / total_frames, 1.0),
                    desc="Processing video...",
                )
            else:
                # Unknown total: report steps completed instead of dividing by 0.
                progress((frame_count, None), desc="Processing video...")
    finally:
        # Release the capture and writer even if a frame fails mid-way.
        cap.release()
        if out is not None:
            out.release()
    return output_path
# Gradio UI: the input video and the trigger button sit in the left column,
# the segmented result in the right column.
demo = gr.Blocks(theme=gr.themes.Soft())
with demo:
    gr.Markdown("# Video Segmentation with SAM")
    gr.Markdown("Upload a video to segment objects using Segment Anything Model")

    with gr.Row():
        with gr.Column():
            input_video = gr.Video(label="Input Video")
            process_btn = gr.Button("Process Video", variant="primary")
        with gr.Column():
            output_video = gr.Video(label="Segmented Video")

    # Clicking the button runs the segmentation; the same handler is exposed
    # through the API under the name "segment_video".
    process_btn.click(
        process_video_sam,
        inputs=input_video,
        outputs=output_video,
        api_name="segment_video",
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()