---
library_name: transformers
tags:
- trl
- sft
---

# tinygemma3 with vision

This model is trained on the [CIFAR-10](https://huggingface.co/datasets/uoft-cs/cifar10) dataset.

How to use:

```py
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "ngxson/tinygemma3_cifar"

model = AutoModelForImageTextToText.from_pretrained(model_id).to("cuda")  # use "cpu" if no GPU is available
processor = AutoProcessor.from_pretrained(model_id)


#####################

from datasets import load_dataset

ds_full = load_dataset("uoft-cs/cifar10")

def ex_to_msg(ex):
    # Build a single-turn chat message: a text prompt plus one image slot
    txt = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is this:"},
                {"type": "image"}
            ]
        }
    ]
    img = ex["img"]
    return {
        "messages": txt,
        "images": [img],
    }


#####################

# Pick one training example and run it through the chat template
test_idx = 0

test_msg = ex_to_msg(ds_full["train"][test_idx])

test_txt = processor.apply_chat_template(test_msg["messages"], tokenize=False, add_generation_prompt=True)
test_input = processor(text=test_txt, images=test_msg["images"], return_tensors="pt").to(model.device)

#####################

# Greedy decoding; only one new token is needed for the class label
generated_ids = model.generate(**test_input, do_sample=False, max_new_tokens=1)
generated_texts = processor.batch_decode(
    generated_ids,
    skip_special_tokens=True,
)
print(generated_texts)

# expected answer for test_idx = 0 is "airplane"

```
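The generated text should be one of the ten CIFAR-10 class names. To compare a prediction against the dataset's integer label, a lookup like the following works. This is a minimal sketch: the class list below is the standard CIFAR-10 ordering, assumed here rather than read from the dataset (`ds_full["train"].features["label"].names` is the authoritative source).

```py
# Standard CIFAR-10 class names in label-index order (assumed here;
# ds_full["train"].features["label"].names is the authoritative source).
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

def label_name(label_id: int) -> str:
    """Map a CIFAR-10 integer label to its class name."""
    return CIFAR10_CLASSES[label_id]

# e.g. compare against the ground truth of the example used above:
#   label_name(ds_full["train"][test_idx]["label"]) == generated_texts[0].strip()
print(label_name(0))  # → airplane
```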