# Gradio demo: Qwen2-VL-7B guessing game — the model must say which of
# three colour grids a German description refers to.
from transformers import AutoProcessor, AutoModelForVision2Seq, Qwen2VLForConditionalGeneration
import gradio as gr
from PIL import Image
# Load the Qwen2-VL-7B-Instruct vision-language model once at startup.
# device_map="auto" lets transformers place the weights on the available
# device(s); torch_dtype="auto" uses the checkpoint's stored precision.
model2 = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype="auto",  # or torch.bfloat16
    # attn_implementation="flash_attention_2",
    device_map="auto",
)
# Default processor (tokenizer + image preprocessor) matching the model.
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct")
# Game rules, in German, prepended to every user prompt: the image shows
# three colour grids; the model must state whether the description refers
# to the left, middle, or right grid, and answer in German.
GAME_RULES = """In diesem Bild sehen Sie drei Farbenraster.
In der folgenden Beschreibung wird genau eines der Raster beschrieben.
Bitte geben Sie an, ob sich der Sprecher auf das linke, mittlere oder rechte Raster bezieht.
Antworten Sie auf Deutsch.
"""
# Dropdown label -> path of the example image shown for that choice.
# Files are expected in the working directory — TODO confirm at deploy time.
IMAGE_OPTIONS = {
    "Grid 1": "example1.jpg",
    "Grid 2": "example2.jpg",
    "Grid 3": "example3.jpg",
    "Grid 4": "example4.jpg",
    "Grid 5": "example5.jpg"
}
# Function to run model
def play_game(selected_image_label, user_prompt):
    """Run one game round against Qwen2-VL and return its German answer.

    Args:
        selected_image_label: Key into ``IMAGE_OPTIONS`` ("Grid 1" .. "Grid 5")
            naming the example image to use.
        user_prompt: The player's German description of one of the three
            grids; may be empty.

    Returns:
        The model's decoded reply as a string, with the echoed prompt
        tokens stripped so only newly generated text remains.
    """
    # Fall back to a neutral request when the player submitted nothing,
    # so the model never receives the rules with no description at all.
    user_prompt = user_prompt or "Bitte beschreiben Sie das Gitter."
    selected_image_path = IMAGE_OPTIONS[selected_image_label]
    selected_image = Image.open(selected_image_path)

    # Build a single-turn chat message: one image plus rules + description.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": selected_image},
                {"type": "text", "text": GAME_RULES + "\n" + user_prompt},
            ],
        }
    ]

    # Render the chat template to plain text, then tokenize text and image
    # together and move the tensors onto the model's device.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text],
        images=[selected_image],
        return_tensors="pt",
    ).to(model2.device)

    # Generate, then drop the prompt prefix from each output sequence so
    # batch_decode only sees the newly generated answer tokens.
    generated_ids = model2.generate(**inputs, max_new_tokens=512)
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return output_text
# Gradio App: dropdown to pick an example image, textbox for the player's
# description, and a button that sends both to the model.
with gr.Blocks() as demo:
    with gr.Column():
        image_selector = gr.Dropdown(
            choices=list(IMAGE_OPTIONS.keys()),
            value="Grid 1",
            # Fixed mojibake: original label was double-decoded ("WΓ€hlen").
            label="Wählen Sie ein Bild"
        )
        image_display = gr.Image(
            value=Image.open(IMAGE_OPTIONS["Grid 1"]),
            label="Ihr Bild",
            interactive=False,
            type="pil"
        )
        prompt_input = gr.Textbox(
            value="Beschreiben Sie das Farbenraster...",
            label="Ihre Beschreibung"
        )
        output_text = gr.Textbox(label="Antwort des Modells")
        play_button = gr.Button("Starte das Spiel")

    def update_image(selected_label):
        """Return the PIL image matching the newly selected dropdown label."""
        selected_path = IMAGE_OPTIONS[selected_label]
        return Image.open(selected_path)

    # When user changes selection, update image
    image_selector.change(
        fn=update_image,
        inputs=[image_selector],
        outputs=image_display
    )
    # When user clicks play, send inputs to model
    play_button.click(
        fn=play_game,
        inputs=[image_selector, prompt_input],
        outputs=output_text
    )

# share=True exposes a public gradio.live URL; port is fixed for the demo.
demo.launch(share=True, server_port=4879)