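"""Gradio chat demo for SmolVLM, a small vision-language model.

Builds a multimodal ChatInterface that accepts one or more images plus a text
query and streams the model's reply token by token via TextIteratorStreamer.
"""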
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
# Load the SmolVLM model and processor
print("π§ Loading SmolVLM model...")
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct")
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceTB/SmolVLM-256M-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto"  # Automatically handles CPU/GPU placement
)
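# Note: bfloat16 halves memory use compared with float32; if the target
# hardware lacks bfloat16 support, torch.float32 is a safe (larger) fallback.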
print("β
Model loaded successfully!")
def model_inference(input_dict, history):
    """Process multimodal input and generate a streamed response."""
    text = input_dict["text"]

    # Load any attached images
    images = [load_image(path) for path in input_dict["files"]]

    # Validation
    if text == "" and not images:
        raise gr.Error("Please input a query and optionally image(s).")
    if text == "" and images:
        raise gr.Error("Please input a text query along with the image(s).")
    # Prepare the conversation format
    resulting_messages = [
        {
            "role": "user",
            "content": [{"type": "image"} for _ in range(len(images))]
            + [{"type": "text", "text": text}],
        }
    ]
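    # apply_chat_template expands each {"type": "image"} entry into the model's
    # image placeholder tokens, so the number of placeholders must match the
    # number of images passed to the processor.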
    try:
        # Apply chat template and process inputs
        prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
        inputs = processor(text=prompt, images=images if images else None, return_tensors="pt")

        # Move tensors to the same device as the model
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Set up streaming generation
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
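        # The processor works where a tokenizer is expected because it forwards
        # decode() to its underlying tokenizer, which is all the streamer needs.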
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=500,
            min_new_tokens=10,
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

        # Start generation in a separate thread
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
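        # generate() blocks inside the worker thread while this generator
        # consumes decoded text from the streamer below, so partial output
        # reaches the UI as soon as new tokens arrive.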
        # Stream the response
        yield "Thinking..."
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            time.sleep(0.02)  # Small delay for smooth streaming
            yield buffer
    except Exception as e:
        yield f"❌ Error generating response: {e}"
# Example prompts for demonstration (upload your own image, then pick one)
examples = [
    [{"text": "What do you see in this image?", "files": []}],
    [{"text": "Describe the colors and objects in this image in detail.", "files": []}],
    [{"text": "What is the mood or atmosphere of this image?", "files": []}],
    [{"text": "Are there any people in this image? What are they doing?", "files": []}],
    [{"text": "What text can you read in this image?", "files": []}],
    [{"text": "Count the number of objects you can see.", "files": []}],
]
# Create the Gradio interface using ChatInterface
demo = gr.ChatInterface(
    fn=model_inference,
    title="🚀 SmolVLM Vision Chat",
    description="""
Chat with **SmolVLM-256M-Instruct**, a compact but powerful vision-language model!

**How to use:**
1. Upload one or more images using the 📎 button
2. Ask questions about the images
3. Get detailed AI-generated descriptions and answers

**Example questions:**
- "What do you see in this image?"
- "Describe the colors and composition"
- "What text is visible in this image?"
- "Count the objects in this image"

This model can analyze photos, diagrams, documents, artwork, and more!
""",
    examples=examples,
    textbox=gr.MultimodalTextbox(
        label="💬 Ask about your images...",
        file_types=["image"],
        file_count="multiple",
        placeholder="Upload images and ask questions about them!",
    ),
    stop_btn="⏹️ Stop Generation",
    multimodal=True,
    cache_examples=False,
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1000px !important;
    }
    .chat-message {
        border-radius: 10px !important;
    }
    """,
)
if __name__ == "__main__":
    print("🚀 Launching SmolVLM Chat Interface...")
    demo.launch(
        server_name="0.0.0.0",  # Listen on all network interfaces
        server_port=7860,
        share=False,
        show_error=True,
    )
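# Note: set share=True in demo.launch() above to get a temporary public
# gradio.live URL when running on a remote machine.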