File size: 1,937 Bytes
fb1f781
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2295e60
fb1f781
 
 
2295e60
 
fb1f781
 
 
 
 
 
 
 
 
2295e60
fb1f781
 
 
2295e60
fb1f781
 
 
 
2295e60
fb1f781
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import importlib.util
import os
from datetime import datetime

import torch
from huggingface_hub import snapshot_download
from peft import PeftModel
from PIL import Image, ImageOps
from transformers import AutoProcessor, AutoModelForVision2Seq

# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

base_model_name = "HuggingFaceTB/SmolVLM-256M-Instruct"

# The processor only prepares text and images for the model; dtype and
# attention-backend options are model-loading options, so they are passed to
# the model below rather than here.
processor = AutoProcessor.from_pretrained(base_model_name)

# `device` is a torch.device object, so its `.type` is what gets compared with
# the string "cuda". Flash-attention is requested only when the optional
# `flash_attn` package can be imported, so that a CUDA host without it still
# starts up with the "eager" implementation.
# NOTE(review): flash-attention also has GPU-architecture requirements that
# are not checked here -- confirm on the target hardware.
_use_flash_attention = (
    device.type == "cuda"
    and importlib.util.find_spec("flash_attn") is not None
)

base_model = AutoModelForVision2Seq.from_pretrained(
    base_model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2" if _use_flash_attention else "eager",
).to(device)


# Fetch a local snapshot of the adapter repository from the Hugging Face Hub.
repo_local_path = snapshot_download(repo_id="Irina1402/smolvlm-painting-description")

# Wrap the base model with the PEFT weights found in that snapshot, then
# switch the combined model to evaluation mode.
model = PeftModel.from_pretrained(
    base_model,
    repo_local_path,
)
model.eval()



def process_chat(text: str = None, image: Image.Image = None):
    """Generate the model's reply to a text prompt, an image, or both.

    Args:
        text: Optional user prompt. Left out of the chat message when it is
            ``None`` or empty.
        image: Optional PIL image. It is rotated according to its EXIF
            orientation and converted to RGB before being handed to the
            processor.

    Returns:
        The newly generated text with surrounding whitespace stripped.

    Raises:
        ValueError: If neither ``text`` nor ``image`` is provided.
    """
    content = []
    image_data = None

    if image is not None:
        # Apply the EXIF orientation first: the orientation tag belongs to the
        # original image's metadata and is not guaranteed to survive a mode
        # conversion.
        image_data = ImageOps.exif_transpose(image).convert("RGB")
        content.append({"type": "image"})

    if text:
        content.append({"type": "text", "text": text})

    if not content:
        raise ValueError("process_chat requires text, an image, or both.")

    # Only the parts that were actually supplied go into the message, so the
    # prompt never contains an image placeholder without a matching image or
    # a text entry without text.
    message = [{"role": "user", "content": content}]

    prompt = processor.apply_chat_template(message, add_generation_prompt=True)

    print(f"Prepared prompt:\n{prompt}")

    processed_inputs = processor(
        text=prompt,
        images=[image_data] if image_data is not None else None,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        generated_ids = model.generate(
            **processed_inputs,
            max_new_tokens=50,
            repetition_penalty=1.2,
        )

    # Decode only the tokens produced after the prompt. Splitting the full
    # decoded text on "Assistant:" would cut in the wrong place whenever the
    # user's own text contains that marker.
    prompt_length = processed_inputs["input_ids"].shape[1]
    new_token_ids = generated_ids[:, prompt_length:]
    assistant_text = processor.batch_decode(new_token_ids, skip_special_tokens=True)[0].strip()

    return assistant_text