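# Real-time webcam captioning with SmolVLM2, served through Gradio and
# adapted to run on Hugging Face ZeroGPU Spaces.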
import spaces
import time
import logging
import gradio as gr
from transformers import AutoProcessor, AutoModelForImageTextToText
import torch
from PIL import Image
# Cache for loaded model and processor
default_cache = {'model_id': None, 'processor': None, 'model': None, 'device': None}
model_cache = default_cache.copy()
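# Reloading weights for every frame would dwarf inference time, so the cache
# is refreshed only when the requested model ID or device actually changes.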
# Check for XPU availability
has_xpu = hasattr(torch, 'xpu') and torch.xpu.is_available()
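# torch.xpu is PyTorch's Intel GPU backend; the hasattr guard keeps this
# check safe on builds that predate it.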
def update_model(model_id, device):
if model_cache['model_id'] != model_id or model_cache['device'] != device:
logging.info(f'Loading model {model_id} on {device}')
processor = AutoProcessor.from_pretrained(model_id)
# Load model with appropriate precision for each device
if device == 'cuda':
# Use bfloat16 for CUDA for performance
model = AutoModelForImageTextToText.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
_attn_implementation='flash_attention_2'
).to('cuda')
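            # Note: flash_attention_2 requires the flash-attn package to be
            # installed; dropping the flag falls back to PyTorch's built-in
            # attention kernels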
elif device == 'xpu' and has_xpu:
# Use float32 on XPU to avoid bfloat16 layernorm issues
model = AutoModelForImageTextToText.from_pretrained(
model_id,
torch_dtype=torch.float32
).to('xpu')
else:
# Default to float32 on CPU
model = AutoModelForImageTextToText.from_pretrained(model_id).to('cpu')
model.eval()
model_cache.update({'model_id': model_id, 'processor': processor, 'model': model, 'device': device})
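# On ZeroGPU Spaces, @spaces.GPU attaches a GPU only for the duration of each
# decorated call and releases it afterwards.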
@spaces.GPU
def caption_frame(frame, model_id, interval_ms, sys_prompt, usr_prompt, device):
debug_msgs = []
update_model(model_id, device)
processor = model_cache['processor']
model = model_cache['model']
    # Pace the stream: sleep for the configured interval before handling this frame
    time.sleep(interval_ms / 1000)
# Preprocess frame
t0 = time.time()
    # Gradio delivers webcam frames as RGB numpy arrays, so the frame can be
    # wrapped in a PIL image directly (no BGR -> RGB conversion needed)
    pil_img = Image.fromarray(frame)
temp_path = 'frame.jpg'
pil_img.save(temp_path, format='JPEG', quality=50)
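    # quality=50 keeps the temporary JPEG small; raise it if fine visual
    # detail matters more than per-frame latency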
debug_msgs.append(f'Preprocess: {int((time.time()-t0)*1000)} ms')
# Prepare multimodal chat messages
messages = [
{'role': 'system', 'content': [{'type': 'text', 'text': sys_prompt}]},
{'role': 'user', 'content': [
{'type': 'image', 'url': temp_path},
{'type': 'text', 'text': usr_prompt}
]}
]
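    # apply_chat_template resolves the image entry from the local file path
    # and interleaves it with the text turns per the model's chat format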
# Tokenize and encode
t1 = time.time()
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors='pt'
    ).to(model.device, dtype=model.dtype)  # cast floating inputs to the model's dtype (bfloat16 on CUDA) to avoid mismatches
debug_msgs.append(f'Tokenize: {int((time.time()-t1)*1000)} ms')
# Inference
t2 = time.time()
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=128)
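    # Greedy decoding (do_sample=False) keeps captions deterministic between
    # frames; 128 new tokens is ample for a short description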
debug_msgs.append(f'Inference: {int((time.time()-t2)*1000)} ms')
# Decode and strip history
t3 = time.time()
raw = processor.batch_decode(outputs, skip_special_tokens=True)[0]
debug_msgs.append(f'Decode: {int((time.time()-t3)*1000)} ms')
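    # batch_decode returns the full rendered conversation (prompts included),
    # so everything before the assistant's turn is stripped below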
if "Assistant:" in raw:
caption = raw.split("Assistant:")[-1].strip()
else:
lines = raw.splitlines()
caption = lines[-1].strip() if len(lines) > 1 else raw.strip()
return caption, '\n'.join(debug_msgs)
def main():
logging.basicConfig(level=logging.INFO)
model_choices = [
'HuggingFaceTB/SmolVLM2-256M-Video-Instruct',
'HuggingFaceTB/SmolVLM2-500M-Video-Instruct',
'HuggingFaceTB/SmolVLM2-2.2B-Instruct'
]
# Determine available devices
device_options = ['cpu']
if torch.cuda.is_available():
device_options.append('cuda')
if has_xpu:
device_options.append('xpu')
default_device = 'cuda' if torch.cuda.is_available() else ('xpu' if has_xpu else 'cpu')
with gr.Blocks() as demo:
gr.Markdown('## 🎥 Real-Time Webcam Captioning with SmolVLM2 (Transformers)')
with gr.Row():
model_dd = gr.Dropdown(model_choices, value=model_choices[0], label='Model ID')
device_dd = gr.Dropdown(device_options, value=default_device, label='Device')
interval = gr.Slider(100, 20000, step=100, value=3000, label='Interval (ms)')
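        # The interval value feeds the sleep() in caption_frame, throttling
        # how often frames are captioned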
sys_p = gr.Textbox(lines=2, value='Describe the key action', label='System Prompt')
usr_p = gr.Textbox(lines=1, value='What is happening in this image?', label='User Prompt')
cam = gr.Image(sources=['webcam'], streaming=True, label='Webcam Feed')
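        # streaming=True makes the webcam push frames continuously; each frame
        # invokes caption_frame via the stream event below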
caption_tb = gr.Textbox(interactive=False, label='Caption')
log_tb = gr.Textbox(lines=4, interactive=False, label='Debug Log')
cam.stream(
fn=caption_frame,
inputs=[cam, model_dd, interval, sys_p, usr_p, device_dd],
outputs=[caption_tb, log_tb],
time_limit=600
)
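        # time_limit caps each streaming session at 600 s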
    # Enable Gradio's event queue; streaming events are dispatched through it
demo.queue()
# Launch the app
demo.launch()
if __name__ == '__main__':
main()