ManarAli committed on
Commit
1460a51
·
verified ·
1 Parent(s): 109ec9f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -0
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from transformers import AutoProcessor, AutoModelForVision2Seq, Qwen2VLForConditionalGeneration
import gradio as gr
from PIL import Image

# NOTE(review): AutoModelForVision2Seq appears unused below — confirm before removing.

# Load the Qwen2-VL 7B instruct vision-language model once at import time.
# torch_dtype="auto" uses the checkpoint's preferred dtype; device_map="auto"
# lets accelerate place layers on the available device(s).
model2 = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype="auto",  # or torch.bfloat16
    # attn_implementation="flash_attention_2",
    device_map="auto",
)

# Default processor (tokenizer + image preprocessor) matching the model checkpoint.
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct")
# Game instructions (German), prepended to every user prompt: the image shows
# three color grids and the model must say whether the description refers to
# the left, middle, or right one, answering in German.
GAME_RULES = """In diesem Bild sehen Sie drei Farbenraster.
In der folgenden Beschreibung wird genau eines der Raster beschrieben.
Bitte geben Sie an, ob sich der Sprecher auf das linke, mittlere oder rechte Raster bezieht.
Antworten Sie auf Deutsch.
"""

# Dropdown label -> image file on disk; five example grid images.
IMAGE_OPTIONS = {f"Grid {i}": f"example{i}.jpg" for i in range(1, 6)}
# Function to run the model for one round of the game.
def play_game(selected_image_label, user_prompt):
    """Ask the model which grid the player's description refers to.

    Args:
        selected_image_label: Key into IMAGE_OPTIONS ("Grid 1" .. "Grid 5").
        user_prompt: The player's free-text description of one of the grids.

    Returns:
        The model's decoded answer as a string (expected to be German text).
    """
    # (Removed a dead no-op `user_prompt = user_prompt`; its commented-out
    # fallback text was never active, so behavior is unchanged.)
    selected_image_path = IMAGE_OPTIONS[selected_image_label]
    selected_image = Image.open(selected_image_path)

    # Build a single-turn chat message carrying the image plus the game
    # rules concatenated with the player's description.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": selected_image},
                {"type": "text", "text": GAME_RULES + "\n" + user_prompt},
            ],
        }
    ]

    # Render the chat template, then tokenize text and image together and
    # move the tensors onto the model's device.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[text],
        images=[selected_image],
        return_tensors="pt",
    ).to(model2.device)

    # Generate, then strip the prompt tokens so only the new answer remains.
    generated_ids = model2.generate(**inputs, max_new_tokens=512)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    return output_text
# Gradio app: pick one of the example grid images, type a description, and
# show the model's answer about which grid the description refers to.
# NOTE: component creation order defines the on-screen layout — keep it.
with gr.Blocks() as demo:
    with gr.Column():
        image_selector = gr.Dropdown(
            choices=list(IMAGE_OPTIONS.keys()),
            value="Grid 1",
            # Fixed mojibake in the user-facing label ("WΓ€hlen" -> "Wählen").
            label="Wählen Sie ein Bild"
        )

        # Preview of the currently selected grid image (read-only).
        image_display = gr.Image(
            value=Image.open(IMAGE_OPTIONS["Grid 1"]),
            label="Ihr Bild",
            interactive=False,
            type="pil"
        )

        prompt_input = gr.Textbox(
            value="Beschreiben Sie das Farbenraster...",
            label="Ihre Beschreibung"
        )

        output_text = gr.Textbox(label="Antwort des Modells")

        play_button = gr.Button("Starte das Spiel")

        def update_image(selected_label):
            # Resolve the dropdown label to its file path and load the preview.
            selected_path = IMAGE_OPTIONS[selected_label]
            return Image.open(selected_path)

        # When the user changes the selection, update the preview image.
        image_selector.change(
            fn=update_image,
            inputs=[image_selector],
            outputs=image_display
        )

        # When the user clicks play, send the selection and prompt to the model.
        play_button.click(
            fn=play_game,
            inputs=[image_selector, prompt_input],
            outputs=output_text
        )

demo.launch(share=True, server_port=4879)