import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import numpy as np
import os
import tempfile
import spaces
import gradio as gr
import subprocess
import sys
import cv2
import threading
import queue
import time
from collections import deque
from deep_translator import GoogleTranslator
def install_flash_attn_wheel():
    flash_attn_wheel_url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", flash_attn_wheel_url])
        print("Wheel installed successfully!")
    except subprocess.CalledProcessError as e:
        print(f"Failed to install the flash attention wheel. Error: {e}")

install_flash_attn_wheel()

try:
    from mmengine.visualization import Visualizer
except ImportError:
    Visualizer = None
    print("Warning: mmengine is not installed, visualization is disabled.")

# Load the model and tokenizer
model_path = "ByteDance/Sa2VA-4B"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype="auto",
    device_map="cuda:0",
    trust_remote_code=True,
).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)
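
# WebcamProcessor runs two daemon threads: a capture loop that samples webcam
# frames at roughly fps_target into a small ring buffer (deque), and a process
# loop that, once the buffer is full, runs model.predict_forward on the buffered
# clip and pushes the result onto a queue for the UI handlers to consume.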
class WebcamProcessor:
    def __init__(self, model, tokenizer, fps_target=15, buffer_size=5):
        self.model = model
        self.tokenizer = tokenizer
        self.fps_target = fps_target
        self.frame_interval = 1.0 / fps_target
        self.buffer_size = buffer_size
        self.frame_buffer = deque(maxlen=buffer_size)
        self.result_queue = queue.Queue()
        self.is_running = False
        self.last_process_time = 0

    def start(self):
        try:
            self.is_running = True
            self.capture = cv2.VideoCapture(0)
            if not self.capture.isOpened():
                raise Exception("Failed to open webcam")
            # Set camera properties
            self.capture.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
            self.capture.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
            self.capture_thread = threading.Thread(target=self._capture_loop)
            self.process_thread = threading.Thread(target=self._process_loop)
            self.capture_thread.daemon = True
            self.process_thread.daemon = True
            self.capture_thread.start()
            self.process_thread.start()
            return "Webcam started successfully"
        except Exception as e:
            self.is_running = False
            return f"Failed to start webcam: {str(e)}"

    def stop(self):
        try:
            self.is_running = False
            if hasattr(self, 'capture_thread'):
                self.capture_thread.join(timeout=1.0)
            if hasattr(self, 'process_thread'):
                self.process_thread.join(timeout=1.0)
            if hasattr(self, 'capture'):
                self.capture.release()
            return "Webcam stopped successfully"
        except Exception as e:
            return f"Error stopping webcam: {str(e)}"

    def _capture_loop(self):
        while self.is_running:
            try:
                ret, frame = self.capture.read()
                if ret:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame = cv2.resize(frame, (640, 480))
                    current_time = time.time()
                    if current_time - self.last_process_time >= self.frame_interval:
                        self.frame_buffer.append(frame)
                        self.last_process_time = current_time
                time.sleep(0.01)  # Small delay to prevent CPU overuse
            except Exception as e:
                print(f"Capture error: {e}")
                time.sleep(0.1)

    def _process_loop(self):
        while self.is_running:
            try:
                if len(self.frame_buffer) >= self.buffer_size:
                    frames = list(self.frame_buffer)
                    result = self.model.predict_forward(
                        video=frames,
                        text="<image>Describe what you see",
                        tokenizer=self.tokenizer,
                    )
                    self.result_queue.put(result)
                    self.frame_buffer.clear()
                time.sleep(0.1)
            except Exception as e:
                print(f"Processing error: {e}")
                time.sleep(0.1)
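
# Video helpers: VideoReader (bundled with the Space under third_parts) decodes
# the clip; read_video keeps every `video_interval`-th frame, reverses the
# channel order (BGR decode -> RGB for PIL), and saves each kept frame as a
# JPEG so that segmentation masks can later be drawn on the frames from disk.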
from third_parts import VideoReader

def read_video(video_path, video_interval):
    vid_frames = VideoReader(video_path)[::video_interval]
    temp_dir = tempfile.mkdtemp()
    os.makedirs(temp_dir, exist_ok=True)
    image_paths = []
    for frame_idx in range(len(vid_frames)):
        frame_image = vid_frames[frame_idx]
        frame_image = frame_image[..., ::-1]
        frame_image = Image.fromarray(frame_image)
        vid_frames[frame_idx] = frame_image
        image_path = os.path.join(temp_dir, f"frame_{frame_idx:04d}.jpg")
        frame_image.save(image_path, format="JPEG")
        image_paths.append(image_path)
    return vid_frames, image_paths

def visualize(pred_mask, image_path, work_dir):
    visualizer = Visualizer()
    img = cv2.imread(image_path)
    visualizer.set_image(img)
    visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
    visual_result = visualizer.get_image()
    output_path = os.path.join(work_dir, os.path.basename(image_path))
    cv2.imwrite(output_path, visual_result)
    return output_path

def translate_to_korean(text):
    try:
        translator = GoogleTranslator(source='en', target='ko')
        return translator.translate(text)
    except Exception as e:
        print(f"Translation error: {e}")
        return text
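
# The handlers below assume Sa2VA's predict_forward output format: a text
# "prediction" that contains a [SEG] token whenever "prediction_masks" holds
# segmentation masks; when [SEG] is present the masks are overlaid on the input
# with mmengine's Visualizer. Prompts containing Hangul get translated replies.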
def image_vision(image_input_path, prompt):
    is_korean = any(ord('가') <= ord(char) <= ord('힣') for char in prompt)
    image_path = image_input_path
    text_prompts = f"<image>{prompt}"
    image = Image.open(image_path).convert('RGB')
    input_dict = {
        'image': image,
        'text': text_prompts,
        'past_text': '',
        'mask_prompts': None,
        'tokenizer': tokenizer,
    }
    return_dict = model.predict_forward(**input_dict)
    print(return_dict)
    answer = return_dict["prediction"]
    if is_korean:
        if '[SEG]' in answer:
            parts = answer.split('[SEG]')
            translated_parts = [translate_to_korean(part.strip()) for part in parts]
            answer = '[SEG]'.join(translated_parts)
        else:
            answer = translate_to_korean(answer)
    seg_image = return_dict["prediction_masks"]
    if '[SEG]' in answer and Visualizer is not None:
        pred_masks = seg_image[0]
        temp_dir = tempfile.mkdtemp()
        pred_mask = pred_masks
        os.makedirs(temp_dir, exist_ok=True)
        seg_result = visualize(pred_mask, image_input_path, temp_dir)
        return answer, seg_result
    else:
        return answer, None

def video_vision(video_input_path, prompt, video_interval):
    is_korean = any(ord('가') <= ord(char) <= ord('힣') for char in prompt)
    cap = cv2.VideoCapture(video_input_path)
    original_fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    frame_skip_factor = video_interval
    new_fps = original_fps / frame_skip_factor
    vid_frames, image_paths = read_video(video_input_path, video_interval)
    question = f"<image>{prompt}"
    result = model.predict_forward(
        video=vid_frames,
        text=question,
        tokenizer=tokenizer,
    )
    prediction = result['prediction']
    print(prediction)
    if is_korean:
        if '[SEG]' in prediction:
            parts = prediction.split('[SEG]')
            translated_parts = [translate_to_korean(part.strip()) for part in parts]
            prediction = '[SEG]'.join(translated_parts)
        else:
            prediction = translate_to_korean(prediction)
    if '[SEG]' in prediction and Visualizer is not None:
        _seg_idx = 0
        pred_masks = result['prediction_masks'][_seg_idx]
        seg_frames = []
        for frame_idx in range(len(vid_frames)):
            pred_mask = pred_masks[frame_idx]
            temp_dir = tempfile.mkdtemp()
            os.makedirs(temp_dir, exist_ok=True)
            seg_frame = visualize(pred_mask, image_paths[frame_idx], temp_dir)
            seg_frames.append(seg_frame)
        output_video = "output_video.mp4"
        frame = cv2.imread(seg_frames[0])
        height, width, layers = frame.shape
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video = cv2.VideoWriter(output_video, fourcc, new_fps, (width, height))
        for img_path in seg_frames:
            frame = cv2.imread(img_path)
            video.write(frame)
        video.release()
        print(f"Video created successfully at {output_video}")
        return prediction, output_video
    else:
        return prediction, None
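
# webcam_vision caches a single WebcamProcessor on the function object itself,
# starts it on first use, and then waits up to 5 seconds for a result produced
# by the background processing thread.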
def webcam_vision(prompt):
    try:
        if not hasattr(webcam_vision, 'processor'):
            webcam_vision.processor = WebcamProcessor(model, tokenizer)
        if not webcam_vision.processor.is_running:
            status = webcam_vision.processor.start()
            if "Failed" in status:
                return f"Error: {status}"
        try:
            result = webcam_vision.processor.result_queue.get(timeout=5)
            prediction = result['prediction']
            # Check if Korean translation is needed
            is_korean = any(ord('가') <= ord(char) <= ord('힣') for char in prompt)
            if is_korean:
                prediction = translate_to_korean(prediction)
            return prediction
        except queue.Empty:
            return "No results available yet. Please try again."
        except Exception as e:
            return f"Processing error: {str(e)}"
    except Exception as e:
        return f"System error: {str(e)}"

# Gradio UI
with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Column():
        gr.Markdown("# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos")

        with gr.Tab("Single Image"):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(label="Image IN", type="filepath")
                    with gr.Row():
                        instruction = gr.Textbox(label="Instruction", scale=4)
                        submit_image_btn = gr.Button("Submit", scale=1)
                with gr.Column():
                    output_res = gr.Textbox(label="Response")
                    output_image = gr.Image(label="Segmentation", type="numpy")
            submit_image_btn.click(
                fn=image_vision,
                inputs=[image_input, instruction],
                outputs=[output_res, output_image]
            )

        with gr.Tab("Video"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(label="Video IN")
                    frame_interval = gr.Slider(label="Frame interval", step=1, minimum=1, maximum=12, value=6)
                    with gr.Row():
                        vid_instruction = gr.Textbox(label="Instruction", scale=4)
                        submit_video_btn = gr.Button("Submit", scale=1)
                with gr.Column():
                    vid_output_res = gr.Textbox(label="Response")
                    output_video = gr.Video(label="Segmentation")
            submit_video_btn.click(
                fn=video_vision,
                inputs=[video_input, vid_instruction, frame_interval],
                outputs=[vid_output_res, output_video]
            )

        with gr.Tab("Webcam"):
            with gr.Row():
                with gr.Column():
                    # Component for webcam input
                    webcam_input = gr.Image(
                        label="Webcam Input",
                        type="numpy",
                        sources="webcam",
                        streaming=True,
                        mirror_webcam=True
                    )
                    with gr.Row():
                        webcam_instruction = gr.Textbox(
                            label="Instruction",
                            placeholder="Enter instruction here...",
                            scale=4
                        )
                        start_button = gr.Button("Start", scale=1)
                        stop_button = gr.Button("Stop", scale=1)
                with gr.Column():
                    webcam_output = gr.Textbox(label="Response")
                    processed_view = gr.Image(label="Processed View")
                    status_text = gr.Textbox(label="Status", value="Ready")

            def start_webcam_processing(instruction):
                try:
                    if hasattr(webcam_vision, 'processor'):
                        webcam_vision.processor.stop()
                    webcam_vision.processor = WebcamProcessor(model, tokenizer)
                    status = webcam_vision.processor.start()
                    return webcam_vision(instruction)
                except Exception as e:
                    return f"Error starting webcam: {str(e)}"

            start_button.click(
                fn=start_webcam_processing,
                inputs=[webcam_instruction],
                outputs=[webcam_output]
            )
            stop_button.click(
                fn=lambda: "Stopped" if hasattr(webcam_vision, 'processor') and webcam_vision.processor.stop() else "Not running",
                outputs=[status_text]
            )

# Additional settings for webcam access
demo.queue().launch(
    server_name="0.0.0.0",  # accessible from any IP
    server_port=7860,       # fixed port
    share=True,             # create a public share link
    show_api=False,
    show_error=True
)