File size: 3,383 Bytes
1460a51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from transformers import AutoProcessor, AutoModelForVision2Seq, Qwen2VLForConditionalGeneration
import gradio as gr
from PIL import Image

# Hugging Face checkpoint id shared by the model and its processor.
MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"

# Load the Qwen2-VL weights once at import time; `play_game` reuses them.
# torch_dtype / device_map are forwarded unchanged to `from_pretrained`
# (see the transformers docs for their exact semantics).
# Optional: attn_implementation="flash_attention_2" (needs flash-attn installed).
model2 = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",  # or torch.bfloat16
    device_map="auto",
)

# Default processor (chat template, tokenizer and image preprocessing) for the
# same checkpoint.
processor = AutoProcessor.from_pretrained(MODEL_ID)


# Instructions prepended to every player description: the model is told the
# image shows three colour grids and must answer, in German, whether the
# description refers to the left, middle or right one.
GAME_RULES = """In diesem Bild sehen Sie drei Farbenraster.
In der folgenden Beschreibung wird genau eines der Raster beschrieben.
Bitte geben Sie an, ob sich der Sprecher auf das linke, mittlere oder rechte Raster bezieht.
Antworten Sie auf Deutsch.
"""

# Dropdown label -> image file path, i.e. "Grid 1" -> "example1.jpg" up to
# "Grid 5" -> "example5.jpg". Insertion order is preserved, so the dropdown
# lists the grids in ascending order.
IMAGE_OPTIONS = {f"Grid {idx}": f"example{idx}.jpg" for idx in range(1, 6)}

def play_game(selected_image_label, user_prompt, *, max_new_tokens=512):
    """Ask the model which of the three grids the player's description means.

    Loads the image registered under *selected_image_label* in IMAGE_OPTIONS,
    appends the player's description to GAME_RULES, and runs one generation
    pass of ``model2``.

    Args:
        selected_image_label: A key of IMAGE_OPTIONS (e.g. "Grid 1").
        user_prompt: The player's description of one grid. ``None`` is
            treated as an empty description.
        max_new_tokens: Upper bound on the number of generated tokens
            (keyword-only; defaults to the previously hard-coded 512).

    Returns:
        The decoded model answer, with special tokens removed.

    Raises:
        KeyError: If *selected_image_label* is not a key of IMAGE_OPTIONS.
    """
    # Gradio sends "" for an empty Textbox. Also accept None so the prompt
    # concatenation below cannot fail with a TypeError.
    if user_prompt is None:
        user_prompt = ""

    selected_image_path = IMAGE_OPTIONS[selected_image_label]
    # Image.open is lazy and keeps the file handle open until the object is
    # garbage-collected. Reading it inside `with` and converting yields an
    # in-memory RGB copy and closes the file immediately.
    with Image.open(selected_image_path) as raw_image:
        selected_image = raw_image.convert("RGB")

    # Single-turn chat: the image, then the rules followed by the description.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": selected_image},
                {"type": "text", "text": GAME_RULES + "\n" + user_prompt},
            ],
        }
    ]

    # Render the chat template to plain text (ending with the generation
    # prompt for the assistant turn), then tokenize it together with the image.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text],
        images=[selected_image],
        return_tensors="pt",
    ).to(model2.device)

    generated_ids = model2.generate(**inputs, max_new_tokens=max_new_tokens)
    # The generated sequences start with the prompt tokens; drop them so only
    # the newly generated answer is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    # The batch has exactly one prompt, so return its only decoded string.
    return processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]


# Gradio App
_DEFAULT_LABEL = "Grid 1"  # grid shown when the page first loads


def update_image(selected_label):
    """Return the image registered under *selected_label* in IMAGE_OPTIONS.

    The file is opened inside a context manager and copied into memory, so
    its handle is closed immediately rather than staying open until the
    Image object is garbage-collected.

    Raises:
        KeyError: If *selected_label* is not a key of IMAGE_OPTIONS.
    """
    with Image.open(IMAGE_OPTIONS[selected_label]) as img:
        return img.copy()


with gr.Blocks() as demo:
    with gr.Column():
        image_selector = gr.Dropdown(
            choices=list(IMAGE_OPTIONS.keys()),
            value=_DEFAULT_LABEL,
            label="Wählen Sie ein Bild",
        )

        image_display = gr.Image(
            value=update_image(_DEFAULT_LABEL),
            label="Ihr Bild",
            interactive=False,
            type="pil",
        )

        prompt_input = gr.Textbox(
            value="Beschreiben Sie das Farbenraster...",
            label="Ihre Beschreibung",
        )

        output_text = gr.Textbox(label="Antwort des Modells")

        play_button = gr.Button("Starte das Spiel")

        # When the user changes the selection, show the matching image.
        image_selector.change(
            fn=update_image,
            inputs=[image_selector],
            outputs=image_display,
        )

        # When the user clicks play, send the selected grid and the
        # description to the model and display its answer.
        play_button.click(
            fn=play_game,
            inputs=[image_selector, prompt_input],
            outputs=output_text,
        )

# share=True also requests a public Gradio share link in addition to the
# local server on port 4879; drop it to keep the demo local-only.
demo.launch(share=True, server_port=4879)