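# Gradio Space: runs SAM (facebook/sam-vit-base) automatic mask generation on
# the first frames of an uploaded video and writes out an annotated copy.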
import torch
import time
import uuid
from typing import Tuple

import gradio as gr
import supervision as sv
import numpy as np
from tqdm import tqdm
from transformers import pipeline
from PIL import Image

START_FRAME = 0
END_FRAME = 10
TOTAL = END_FRAME - START_FRAME
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

SAM_GENERATOR = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-base",
    device=DEVICE)

MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.red(),
    color_lookup=sv.ColorLookup.INDEX)

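# The "mask-generation" pipeline returns a dict whose 'masks' entry is a list
# of per-object boolean masks; run_sam stacks them into an (N, H, W) array so
# supervision can draw every mask at once.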
def run_sam(frame: np.ndarray) -> sv.Detections:
    # convert from NumPy BGR (OpenCV frame layout) to PIL RGB
    image = Image.fromarray(frame[:, :, ::-1])
    outputs = SAM_GENERATOR(image)
    mask = np.array(outputs['masks'])
    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)


def mask_video(source_video: str, prompt: str, confidence: float, name: str) -> str:
    # NOTE: `prompt` and `confidence` are not used yet; SAM runs in automatic
    # mask-generation mode over the first TOTAL frames of the video.
    video_info = sv.VideoInfo.from_video_path(source_video)
    frame_iterator = iter(sv.get_video_frames_generator(
        source_path=source_video, start=START_FRAME, end=END_FRAME))

    with sv.VideoSink(f"{name}.mp4", video_info=video_info) as sink:
        for _ in tqdm(range(TOTAL), desc="Masking frames"):
            frame = next(frame_iterator)
            detections = run_sam(frame)
            annotated_frame = MASK_ANNOTATOR.annotate(
                scene=frame.copy(), detections=detections)
            sink.write_frame(annotated_frame)

    return f"{name}.mp4"


def process(
    source_video: str,
    prompt: str,
    confidence: float,
    progress=gr.Progress(track_tqdm=True)
) -> Tuple[str, str]:
    name = str(uuid.uuid4())
    masked_video = mask_video(source_video, prompt, confidence, name)
    # the painting step is not implemented yet, so the masked video is
    # returned for both the "Masked video" and "Painted video" outputs
    return masked_video, masked_video


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            source_video_player = gr.Video(
                label="Source video", source="upload", format="mp4")
            prompt_text = gr.Textbox(
                label="Prompt", value="person")
            confidence_slider = gr.Slider(
                label="Confidence", minimum=0.5, maximum=1.0, step=0.05, value=0.6)
            submit_button = gr.Button("Submit")
        with gr.Column():
            masked_video_player = gr.Video(label="Masked video")
            painted_video_player = gr.Video(label="Painted video")

    submit_button.click(
        process,
        inputs=[source_video_player, prompt_text, confidence_slider],
        outputs=[masked_video_player, painted_video_player])

demo.queue().launch(debug=False, show_error=True)
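
# Note (assumed setup, not part of the original file): the gr.Video(source=...)
# argument matches the Gradio 3.x API. When hosted as a Space, the imports above
# (torch, transformers, gradio, supervision, numpy, tqdm, pillow) would need to
# be listed in the Space's requirements.txt.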