# NOTE(review): removed copy-paste artifacts from a web file viewer
# (file-size banner, git commit hashes, and a line-number gutter) that
# preceded the code and made this file invalid Python.
from transformers import AutoProcessor, AutoModelForVision2Seq, Qwen2VLForConditionalGeneration
import gradio as gr
from PIL import Image

# Load the Qwen2-VL vision-language model, sharding it across available
# devices automatically (device_map="auto").
model2 = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype="auto",  # or torch.bfloat16
    # attn_implementation="flash_attention_2",  # enable if flash-attn is installed
    device_map="auto",
)

# Default processor (tokenizer + image preprocessor) matching the model.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")


# Game rules: system-style preamble prepended to every user description
# before it is sent to the model.
GAME_RULES = """In this image you can see three color grids. In the following dialogue, the speaker will describe exactly one of the grids. 
Please indicate to me whether he refers to the 
left, middle, or right grid.
"""

# Dropdown label -> image file path for the selectable example images.
# Paths are relative to the working directory; the files must exist there.
IMAGE_OPTIONS = {
    "Image 1": "example1.jpg",
    "Image 2": "example2.jpg",
    "Image 3": "example3.png",
    "Image 4": "example4.jpg"
}

# Function to run the model on one (image, description) pair.
def play_game(selected_image_label, user_prompt):
    """Ask Qwen2-VL which grid the user's description refers to.

    Args:
        selected_image_label: Key into ``IMAGE_OPTIONS`` (e.g. ``"Image 1"``).
        user_prompt: The speaker's free-text description of one grid.

    Returns:
        The model's decoded text response (its left/middle/right guess).

    Raises:
        KeyError: If ``selected_image_label`` is not in ``IMAGE_OPTIONS``.
        FileNotFoundError: If the mapped image file does not exist.
    """
    selected_image = Image.open(IMAGE_OPTIONS[selected_image_label])

    # Single-turn chat: the game rules and the user's description are sent
    # as one user message alongside the image.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": selected_image},
                {"type": "text", "text": GAME_RULES + "\n" + user_prompt},
            ],
        }
    ]

    # Render the chat template, then tokenize text and image together and
    # move the tensors to wherever the (possibly sharded) model lives.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[text],
        images=[selected_image],
        return_tensors="pt",
    ).to(model2.device)

    # Generate, then strip the prompt tokens so only the newly generated
    # completion is decoded.
    generated_ids = model2.generate(**inputs, max_new_tokens=512)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    return processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]


# Gradio App: image picker, live preview, description box, and a button
# that routes the pair through play_game().
with gr.Blocks() as demo:
    _DEFAULT_LABEL = "Image 2"  # single source of truth for the initial selection

    def update_image(selected_label):
        """Return the PIL image for the chosen dropdown label."""
        return Image.open(IMAGE_OPTIONS[selected_label])

    with gr.Column():
        image_selector = gr.Dropdown(
            choices=list(IMAGE_OPTIONS.keys()),
            value=_DEFAULT_LABEL,
            label="Choose an image"
        )

        # Initial preview reuses the same loader as the change handler, so
        # the default image cannot drift out of sync with the dropdown.
        image_display = gr.Image(
            value=update_image(_DEFAULT_LABEL),
            label="Image",
            interactive=False,
            type="pil"
        )

        # NOTE(review): value= pre-fills the box with the literal text
        # "Description". If a greyed-out hint was intended, `placeholder=`
        # is the right keyword — kept as-is to preserve behavior.
        prompt_input = gr.Textbox(
            value="Description",
            label="Your description"
        )

        output_text = gr.Textbox(label="Model's response")

        play_button = gr.Button("start the game")

        # When the user changes the selection, refresh the preview.
        image_selector.change(
            fn=update_image,
            inputs=[image_selector],
            outputs=image_display
        )

        # When the user clicks play, send both inputs to the model.
        play_button.click(
            fn=play_game,
            inputs=[image_selector, prompt_input],
            outputs=output_text
        )

demo.launch()