|
|
|
|
|
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
|
|
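# Load the processor and the instruction-tuned Gemma 3 12B checkpoint once at startup.
# padding_side="left" pads prompts on the left for generation, device_map="auto" places
# the weights on the available GPU(s), and bfloat16 keeps memory use manageable.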
model_id = "google/gemma-3-12b-it"
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
)
|
|
def process_new_user_message(message: dict) -> list[dict]:
    # Convert the Gradio MultimodalTextbox payload into chat-template content items:
    # the text prompt first, followed by one image entry per uploaded file.
    return [{"type": "text", "text": message["text"]}, *[{"type": "image", "url": path} for path in message["files"]]]
|
|
def process_history(history: list[dict]) -> list[dict]:
    # Rebuild prior chat turns in the message format expected by the Gemma 3 chat template,
    # merging consecutive user items (text and images) into a single user turn.
    messages = []
    current_user_content: list[dict] = []
    for item in history:
        if item["role"] == "assistant":
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
        else:
            content = item["content"]
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            else:
                # Image items arrive as a tuple whose first element is the file path.
                current_user_content.append({"type": "image", "url": content[0]})
    return messages
|
|
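# ZeroGPU: request a GPU for up to 120 seconds per call; the function yields the
# growing response so the chat UI updates as tokens are generated.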
@spaces.GPU(duration=120)
def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
    # Assemble the full conversation: optional system prompt, prior turns, then the new user message.
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
    messages.extend(process_history(history))
    messages.append({"role": "user", "content": process_new_user_message(message)})

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device=model.device, dtype=torch.bfloat16)

    # Generate in a background thread and stream the decoded text back as it is produced.
    streamer = TextIteratorStreamer(processor, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    output = ""
    for delta in streamer:
        output += delta
        yield output
|
|
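# Example prompts offered in the UI; each entry pairs a text prompt with one or more
# sample images shipped under assets/.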
examples = [
    [
        {
            "text": "Caption this image.",
            "files": ["assets/sample-images/01.png"],
        }
    ],
    [
        {
            "text": "What does the sign say?",
            "files": ["assets/sample-images/02.png"],
        }
    ],
    [
        {
            "text": "Compare and contrast the two images.",
            "files": ["assets/sample-images/03.png"],
        }
    ],
    [
        {
            "text": "List all the objects in the image and their colors.",
            "files": ["assets/sample-images/04.png"],
        }
    ],
    [
        {
            "text": "Describe the atmosphere of the scene.",
            "files": ["assets/sample-images/05.png"],
        }
    ],
    [
        {
            "text": "Write a poem inspired by the visual elements of the images.",
            "files": ["assets/sample-images/06-1.png", "assets/sample-images/06-2.png"],
        }
    ],
    [
        {
            "text": "Compose a short musical piece inspired by the visual elements of the images.",
            "files": [
                "assets/sample-images/07-1.png",
                "assets/sample-images/07-2.png",
                "assets/sample-images/07-3.png",
                "assets/sample-images/07-4.png",
            ],
        }
    ],
    [
        {
            "text": "Write a short story about what might have happened in this house.",
            "files": ["assets/sample-images/08.png"],
        }
    ],
    [
        {
            "text": "Create a short story based on the sequence of images.",
            "files": [
                "assets/sample-images/09-1.png",
                "assets/sample-images/09-2.png",
                "assets/sample-images/09-3.png",
                "assets/sample-images/09-4.png",
                "assets/sample-images/09-5.png",
            ],
        }
    ],
    [
        {
            "text": "Describe the creatures that would live in this world.",
            "files": ["assets/sample-images/10.png"],
        }
    ],
    [
        {
            "text": "Read the text in the image.",
            "files": ["assets/additional-examples/1.png"],
        }
    ],
    [
        {
            "text": "What date is on this ticket, and how much did it cost?",
            "files": ["assets/additional-examples/2.png"],
        }
    ],
    [
        {
            "text": "Transcribe the text in the image into Markdown.",
            "files": ["assets/additional-examples/3.png"],
        }
    ],
    [
        {
            "text": "Evaluate this integral.",
            "files": ["assets/additional-examples/4.png"],
        }
    ],
]
|
|
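# Wire everything into a multimodal ChatInterface: the textbox accepts image uploads,
# and the system prompt and max-token count are exposed as additional inputs.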
demo = gr.ChatInterface(
    fn=run,
    type="messages",
    textbox=gr.MultimodalTextbox(file_types=["image"], file_count="multiple"),
    multimodal=True,
    additional_inputs=[
        gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=500),
    ],
    stop_btn=False,
    title="Gemma 3 12B it",
    description="<img src='https://huggingface.co/spaces/huggingface-projects/gemma-3-12b-it/resolve/main/assets/logo.png' id='logo' />",
    examples=examples,
    run_examples_on_click=False,
    cache_examples=False,
    css_paths="style.css",
    delete_cache=(1800, 1800),
)
|
|
if __name__ == "__main__":
    demo.launch()
|
|