# Scraped page header (kept as a comment so the file stays valid Python):
# author: SkalskiP
# commit b643479: "Add mask generation to video processing pipeline"
# file size: 2.59 kB
import torch
import time
import uuid
from typing import Tuple
import gradio as gr
import supervision as sv
import numpy as np
from tqdm import tqdm
from transformers import pipeline
from PIL import Image
# Frame window taken from each uploaded video. END_FRAME is exclusive, so the
# window holds END_FRAME - START_FRAME frames.
START_FRAME = 0
END_FRAME = 10
# Number of frames in the window; mask_video uses it as the progress total.
TOTAL = END_FRAME - START_FRAME

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Segment Anything (ViT-B checkpoint) behind the Hugging Face
# "mask-generation" pipeline. It is loaded once, at import time.
SAM_GENERATOR = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-base",
    device=DEVICE,
)

# Draws the masks onto frames. Only a single Color (not a palette) is
# supplied, so presumably every mask is painted red despite the INDEX lookup.
MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.red(),
    color_lookup=sv.ColorLookup.INDEX,
)
def run_sam(frame: np.ndarray) -> sv.Detections:
    """Segment everything in one video frame with SAM.

    Args:
        frame: A (height, width, channel) image array in BGR channel order.

    Returns:
        ``sv.Detections`` with one boolean mask per segment SAM produced and a
        bounding box derived from each mask. If SAM produces no masks, an
        empty ``sv.Detections`` is returned rather than raising.
    """
    # Reverse the channel axis (NumPy BGR -> PIL RGB) before handing the
    # image to the Hugging Face pipeline.
    image = Image.fromarray(frame[:, :, ::-1])
    outputs = SAM_GENERATOR(image)
    masks = outputs["masks"]
    if len(masks) == 0:
        # np.array([]) would be 1-D. sv.Detections requires masks shaped
        # (N, H, W), so it would reject it and abort the whole video job.
        return sv.Detections.empty()
    # Stack into an (N, H, W) boolean array, which is what the mask
    # annotator indexes with.
    mask = np.asarray(masks, dtype=bool)
    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
def mask_video(source_video: str, prompt: str, confidence: float, name: str) -> str:
    """Run SAM over a fixed window of frames and write an annotated MP4.

    Frames from START_FRAME up to END_FRAME (exclusive) are read from
    ``source_video``. Each one is segmented with ``run_sam``, painted with
    ``MASK_ANNOTATOR`` and written to ``{name}.mp4`` (a path relative to the
    current working directory). The output reuses the source's VideoInfo.

    Args:
        source_video: Path to the input video.
        prompt: Accepted for interface compatibility; not used here.
        confidence: Accepted for interface compatibility; not used here.
        name: Base filename (without extension) for the output video.

    Returns:
        The output path, ``f"{name}.mp4"``.
    """
    output_path = f"{name}.mp4"
    video_info = sv.VideoInfo.from_video_path(source_video)
    frames = sv.get_video_frames_generator(
        source_path=source_video, start=START_FRAME, end=END_FRAME)
    with sv.VideoSink(output_path, video_info=video_info) as sink:
        # Iterate the generator itself, so a clip with fewer than TOTAL frames
        # just ends the loop instead of raising StopIteration from next().
        # total= keeps the progress bar sized to the requested window.
        for frame in tqdm(frames, total=TOTAL, desc="Masking frames"):
            detections = run_sam(frame)
            annotated_frame = MASK_ANNOTATOR.annotate(
                scene=frame.copy(), detections=detections)
            sink.write_frame(annotated_frame)
    return output_path
def process(
    source_video: str,
    prompt: str,
    confidence: float,
    progress=gr.Progress(track_tqdm=True)
) -> Tuple[str, str]:
    """Gradio click handler: mask the uploaded video under a random file name.

    ``progress`` is never read in this function. It is declared with
    ``track_tqdm=True`` so Gradio can mirror the tqdm bar that ``mask_video``
    drives.

    Returns:
        The same output path twice: once for the "Masked video" player and
        once for the "Painted video" player.
    """
    output_path = mask_video(
        source_video, prompt, confidence, str(uuid.uuid4()))
    return output_path, output_path
with gr.Blocks() as demo:
    with gr.Row():
        # First column: the inputs.
        with gr.Column():
            # NOTE(review): `source=` is the Gradio 3.x keyword (Gradio 4 uses
            # `sources=[...]`) — confirm against the pinned Gradio version.
            source_video_player = gr.Video(
                label="Source video", source="upload", format="mp4")
            prompt_text = gr.Textbox(label="Prompt", value="person")
            confidence_slider = gr.Slider(
                label="Confidence",
                minimum=0.5,
                maximum=1.0,
                step=0.05,
                value=0.6,
            )
            submit_button = gr.Button("Submit")
        # Second column: the outputs. Both players receive the path returned
        # by `process`.
        with gr.Column():
            masked_video_player = gr.Video(label="Masked video")
            painted_video_player = gr.Video(label="Painted video")

    handler_inputs = [source_video_player, prompt_text, confidence_slider]
    handler_outputs = [masked_video_player, painted_video_player]
    submit_button.click(
        fn=process, inputs=handler_inputs, outputs=handler_outputs)

# Enable the request queue, then start the server.
demo.queue().launch(debug=False, show_error=True)