# Revisual-R1 / app.py
# Author: cyrus28214 — commit cde52cf ("update"), 2.56 kB
# Gradio multimodal chat demo built on SmolVLM-256M-Instruct.
import gradio as gr
import torch
from PIL import Image
from threading import Thread
from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
import spaces
# Pick the compute device and matching dtype: bfloat16 on CUDA GPUs,
# float32 as the CPU fallback (bfloat16 matmuls are slow/unsupported on CPU).
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.bfloat16
else:
    device, torch_dtype = "cpu", torch.float32

MODEL_ID = "HuggingFaceTB/SmolVLM-256M-Instruct"

# Load the processor (tokenizer + image preprocessor) and the vision-language
# model once at import time so every request reuses the same instances.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=torch_dtype,
    trust_remote_code=True,
).to(device)
@spaces.GPU
def respond(
    message,
    history: list[dict],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion for a (possibly image-bearing) user turn.

    Args:
        message: Gradio multimodal payload, a dict with "text" and "files" keys.
        history: Prior turns as OpenAI-style message dicts
            (``type='messages'`` on the ChatInterface).
        system_message: System prompt prepended to the conversation.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature (slider guarantees > 0).
        top_p: Nucleus-sampling cutoff.

    Yields:
        The accumulated response text, growing as tokens stream in.
    """
    messages = [{"role": "system", "content": system_message}]

    # Keep only plain-text turns from history: with multimodal=True, file
    # uploads appear as non-string content entries that the chat template
    # cannot serialize into a prompt.
    for turn in history:
        content = turn.get("content")
        if isinstance(content, str):
            messages.append({"role": turn["role"], "content": content})

    images = []
    if message["files"]:
        # Only the first attachment is used; extend here for multi-image input.
        images.append(Image.open(message["files"][0]).convert("RGB"))

    # SmolVLM chat templates expect an explicit {"type": "image"} placeholder
    # so the processor can align image tokens with the pixel values; a bare
    # string would leave the image tokens out of the prompt entirely.
    if images:
        user_content = [{"type": "image"}, {"type": "text", "text": message["text"]}]
    else:
        user_content = message["text"]
    messages.append({"role": "user", "content": user_content})

    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Pass None (not an empty list) when there are no images — some processors
    # reject an empty image list.
    inputs = processor(text=prompt, images=images or None, return_tensors="pt").to(device, torch_dtype)

    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )

    # Run generation on a worker thread so tokens can be streamed to the UI.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield response
    thread.join()  # don't leave the generation thread dangling
# Assemble the chat UI. The extra inputs map 1:1 onto respond()'s trailing
# parameters (system_message, max_tokens, temperature, top_p).
demo = gr.ChatInterface(
    fn=respond,
    type="messages",
    multimodal=True,
    title="Chatbot",
    description="Ask me anything or upload an image. This version uses AutoModel and AutoProcessor directly.",
    additional_inputs=[
        gr.Textbox(
            value="You are a helpful and friendly multimodal assistant. You can analyze images and answer questions about them.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()