import os

import cv2
import gradio as gr
import numpy as np
import requests
import torch
from segment_anything import sam_model_registry, SamPredictor


def download_sam_model():
    """Download the SAM ViT-H checkpoint if it is not already present."""
    model_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    checkpoint_path = "sam_vit_h_4b8939.pth"
    if not os.path.exists(checkpoint_path):
        print("Downloading SAM model...")
        response = requests.get(model_url, stream=True)
        response.raise_for_status()
        with open(checkpoint_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
        print("Download complete!")
    return checkpoint_path


def process_video_sam(video_path, progress=gr.Progress()):
    checkpoint_path = download_sam_model()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sam = sam_model_registry["vit_h"](checkpoint=checkpoint_path)
    sam.to(device=device)
    predictor = SamPredictor(sam)

    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    output_path = "output_video.mp4"
    out = cv2.VideoWriter(
        output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
    )

    frame_count = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # SAM expects RGB input, but OpenCV decodes frames as BGR.
        predictor.set_image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        # Generate an automatic point prompt at the center of the frame.
        input_point = np.array([[width // 2, height // 2]])
        input_label = np.array([1])

        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )

        # Blend a 50% green overlay into every predicted mask region,
        # casting back to uint8 so the frame stays a valid video buffer.
        annotated_frame = frame.copy()
        for mask in masks:
            annotated_frame[mask] = (
                annotated_frame[mask] * 0.5 + np.array([0, 255, 0]) * 0.5
            ).astype(np.uint8)

        out.write(annotated_frame)
        frame_count += 1
        progress(frame_count / total_frames, desc="Processing video...")

    cap.release()
    out.release()
    return output_path


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Video Segmentation with SAM")
    gr.Markdown("Upload a video to segment objects using Segment Anything Model")

    with gr.Row():
        with gr.Column():
            input_video = gr.Video(label="Input Video")
            process_btn = gr.Button("Process Video", variant="primary")
        with gr.Column():
            output_video = gr.Video(label="Segmented Video")

    process_btn.click(
        fn=process_video_sam,
        inputs=input_video,
        outputs=output_video,
        api_name="segment_video",
    )


if __name__ == "__main__":
    demo.launch()
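
# Because the click handler sets api_name="segment_video", the running app also
# exposes a programmatic endpoint. A minimal client sketch (kept in comments so
# this script stays directly runnable), assuming the default local URL, a local
# "input.mp4", and a recent gradio_client that provides handle_file:
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860/")
#   result = client.predict(handle_file("input.mp4"), api_name="/segment_video")
#   print(result)  # path to the segmented output video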