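"""Gradio demo for Sa2VA (ByteDance/Sa2VA-4B).

Two tabs are exposed: single-image question answering / referring segmentation,
and video segmentation. Predicted masks are overlaid on the input frames with
mmengine's Visualizer when it is available.
"""
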
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import numpy as np
import os
import tempfile
import gradio as gr
import cv2

try:
    from mmengine.visualization import Visualizer
except ImportError:
    Visualizer = None
    print("Warning: mmengine is not installed, visualization is disabled.")

# Load the model and tokenizer
model_path = "ByteDance/Sa2VA-4B"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)
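
# `third_parts` is not a PyPI package; it is assumed to be vendored alongside this
# app from the Sa2VA codebase and provides indexed access to decoded video frames.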
from third_parts import VideoReader


def read_video(video_path, video_interval):
    """Sample every `video_interval`-th frame, convert to PIL RGB, and save as JPEGs."""
    vid_frames = VideoReader(video_path)[::video_interval]

    temp_dir = tempfile.mkdtemp()
    os.makedirs(temp_dir, exist_ok=True)
    image_paths = []  # List to store paths of saved images

    for frame_idx in range(len(vid_frames)):
        frame_image = vid_frames[frame_idx]
        frame_image = frame_image[..., ::-1]  # BGR (OpenCV) to RGB (PIL)
        frame_image = Image.fromarray(frame_image)
        vid_frames[frame_idx] = frame_image

        # Save the frame as a .jpg file in the temporary folder
        image_path = os.path.join(temp_dir, f"frame_{frame_idx:04d}.jpg")
        frame_image.save(image_path, format="JPEG")

        # Append the image path to the list
        image_paths.append(image_path)

    return vid_frames, image_paths


def visualize(pred_mask, image_path, work_dir):
    """Overlay a binary mask on the image and save the result under `work_dir`."""
    visualizer = Visualizer()
    img = cv2.imread(image_path)
    visualizer.set_image(img)
    visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
    visual_result = visualizer.get_image()

    output_path = os.path.join(work_dir, os.path.basename(image_path))
    cv2.imwrite(output_path, visual_result)
    return output_path


def image_vision(image_input_path, prompt):
    """Run Sa2VA on a single image; if the answer contains a [SEG] token,
    render the predicted mask on top of the input image."""
    image_path = image_input_path
    text_prompts = f"<image>{prompt}"
    image = Image.open(image_path).convert('RGB')
    input_dict = {
        'image': image,
        'text': text_prompts,
        'past_text': '',
        'mask_prompts': None,
        'tokenizer': tokenizer,
    }
    return_dict = model.predict_forward(**input_dict)
    print(return_dict)
    answer = return_dict["prediction"]  # the text-format answer
    seg_image = return_dict["prediction_masks"]

    if '[SEG]' in answer and Visualizer is not None:
        pred_masks = seg_image[0]
        temp_dir = tempfile.mkdtemp()
        pred_mask = pred_masks
        os.makedirs(temp_dir, exist_ok=True)
        seg_result = visualize(pred_mask, image_input_path, temp_dir)
        return answer, seg_result
    else:
        return answer, None
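
# Example call (hypothetical path and prompt, for testing outside the UI):
#   answer, seg_path = image_vision("examples/dog.jpg", "Please segment the dog.")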


def video_vision(video_input_path, prompt, video_interval):
    """Run Sa2VA on a video: sample frames at `video_interval`, answer the prompt,
    and, if the answer contains a [SEG] token, write a segmentation overlay video."""
    # Open the original video to read its frame rate
    cap = cv2.VideoCapture(video_input_path)
    original_fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()

    # Sampling every Nth frame divides the effective frame rate by N
    frame_skip_factor = video_interval
    new_fps = original_fps / frame_skip_factor

    vid_frames, image_paths = read_video(video_input_path, video_interval)

    # Create a question (<image> is a placeholder for the video frames)
    question = f"<image>{prompt}"
    result = model.predict_forward(
        video=vid_frames,
        text=question,
        tokenizer=tokenizer,
    )
    prediction = result['prediction']
    print(prediction)

    if '[SEG]' in prediction and Visualizer is not None:
        _seg_idx = 0
        pred_masks = result['prediction_masks'][_seg_idx]
        seg_frames = []
        for frame_idx in range(len(vid_frames)):
            pred_mask = pred_masks[frame_idx]
            temp_dir = tempfile.mkdtemp()
            os.makedirs(temp_dir, exist_ok=True)
            seg_frame = visualize(pred_mask, image_paths[frame_idx], temp_dir)
            seg_frames.append(seg_frame)

        output_video = "output_video.mp4"

        # Read the first frame to get the output resolution
        frame = cv2.imread(seg_frames[0])
        height, width, layers = frame.shape

        # Define the video codec and create the VideoWriter object
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for MP4
        video = cv2.VideoWriter(output_video, fourcc, new_fps, (width, height))

        # Write each overlaid frame to the video
        for img_path in seg_frames:
            frame = cv2.imread(img_path)
            video.write(frame)

        # Release the video writer
        video.release()
        print(f"Video created successfully at {output_video}")

        return prediction, output_video
    else:
        return prediction, None
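
# Example call (hypothetical path and prompt, for testing outside the UI):
#   text, seg_video = video_vision("examples/clip.mp4", "Please segment the person.", 6)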


# Gradio UI
with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Column():
        gr.Markdown("# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos")
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href="https://github.com/magic-research/Sa2VA">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a>
            <a href="https://arxiv.org/abs/2501.04001">
                <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
            </a>
            <a href="https://huggingface.co/spaces/fffiloni/Sa2VA-simple-demo?duplicate=true">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
            </a>
            <a href="https://huggingface.co/fffiloni">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
            </a>
        </div>
        """)

        with gr.Tab("Single Image"):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(label="Image IN", type="filepath")
                    with gr.Row():
                        instruction = gr.Textbox(label="Instruction", scale=4)
                        submit_image_btn = gr.Button("Submit", scale=1)
                with gr.Column():
                    output_res = gr.Textbox(label="Response")
                    output_image = gr.Image(label="Segmentation", type="numpy")

            submit_image_btn.click(
                fn=image_vision,
                inputs=[image_input, instruction],
                outputs=[output_res, output_image]
            )

        with gr.Tab("Video"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(label="Video IN")
                    frame_interval = gr.Slider(label="Frame interval", step=1, minimum=1, maximum=12, value=6)
                    with gr.Row():
                        vid_instruction = gr.Textbox(label="Instruction", scale=4)
                        submit_video_btn = gr.Button("Submit", scale=1)
                with gr.Column():
                    vid_output_res = gr.Textbox(label="Response")
                    output_video = gr.Video(label="Segmentation")

            submit_video_btn.click(
                fn=video_vision,
                inputs=[video_input, vid_instruction, frame_interval],
                outputs=[vid_output_res, output_video]
            )

demo.queue().launch(show_api=False, show_error=True)