import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
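
# Note: the comments below summarize how this demo is wired together.
# gr.ChatInterface drives model_inference() as a generator function;
# model.generate() runs on a background thread and TextIteratorStreamer
# yields decoded tokens back to the UI as they are produced.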

# Load the SmolVLM model and processor
print("πŸ”§ Loading SmolVLM model...")
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct")
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceTB/SmolVLM-Instruct-250M", 
    torch_dtype=torch.bfloat16,
    device_map="auto"  # Automatically handles CPU/GPU placement
)
print("βœ… Model loaded successfully!")

def model_inference(input_dict, history):
    """Process multimodal input and generate response"""
    text = input_dict["text"]
    
    # Handle image input (zero, one, or many uploaded files)
    images = [load_image(path) for path in input_dict["files"]]
    
    # Validation
    if text == "" and not images:
        raise gr.Error("Please input a query and optionally image(s).")
    
    if text == "" and images:
        raise gr.Error("Please input a text query along with the image(s).")
    
    # Prepare the conversation format
    resulting_messages = [
        {
            "role": "user",
            "content": [{"type": "image"} for _ in range(len(images))] + [
                {"type": "text", "text": text}
            ]
        }
    ]
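    # Each {"type": "image"} placeholder above is expanded by the chat template
    # into the model's image tokens; the actual images are passed to the
    # processor below in the same order.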
    
    try:
        # Apply chat template and process inputs
        prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
        inputs = processor(text=prompt, images=images if images else None, return_tensors="pt")
        
        # Move to appropriate device
        device = next(model.parameters()).device
        inputs = {k: v.to(device) if v is not None else v for k, v in inputs.items()}
        
        # Set up streaming generation
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=500,
            min_new_tokens=10,
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.7,
            top_p=0.9
        )
        
        # Start generation in separate thread
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        
        # Stream the response
        yield "Thinking..."
        buffer = ""
        
        for new_text in streamer:
            buffer += new_text
            time.sleep(0.02)  # Small delay for smooth streaming
            yield buffer
            
    except Exception as e:
        yield f"❌ Error generating response: {str(e)}"

# Example prompts and images for demonstration
examples = [
    [{"text": "What do you see in this image?", "files": []}],
    [{"text": "Describe the colors and objects in this image in detail.", "files": []}],
    [{"text": "What is the mood or atmosphere of this image?", "files": []}],
    [{"text": "Are there any people in this image? What are they doing?", "files": []}],
    [{"text": "What text can you read in this image?", "files": []}],
    [{"text": "Count the number of objects you can see.", "files": []}],
]

# Create the Gradio interface using ChatInterface
demo = gr.ChatInterface(
    fn=model_inference,
    title="πŸ” SmolVLM Vision Chat",
    description="""
    Chat with **SmolVLM-256M**, a compact but powerful vision-language model! 
    
    **How to use:**
    1. Upload one or more images using the πŸ“Ž button
    2. Ask questions about the images
    3. Get detailed AI-generated descriptions and answers
    
    **Example questions:**
    - "What do you see in this image?"
    - "Describe the colors and composition"
    - "What text is visible in this image?"
    - "Count the objects in this image"
    
    This model can analyze photos, diagrams, documents, artwork, and more!
    """,
    examples=examples,
    textbox=gr.MultimodalTextbox(
        label="πŸ’¬ Ask about your images...", 
        file_types=["image"], 
        file_count="multiple",
        placeholder="Upload images and ask questions about them!"
    ),
    stop_btn="⏹️ Stop Generation",
    multimodal=True,
    cache_examples=False,
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1000px !important;
    }
    .chat-message {
        border-radius: 10px !important;
    }
    """
)

if __name__ == "__main__":
    print("πŸš€ Launching SmolVLM Chat Interface...")
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )
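
# A typical way to run this demo locally (assumes gradio, transformers, torch,
# and pillow are installed; versions are not pinned here):
#   pip install gradio transformers torch pillow
#   python app.py   # assuming this file is saved as app.py
# Then open http://localhost:7860 in a browser.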