import os
import queue
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
############################################################
# Model setup (modify as needed)
############################################################
DESCRIPTION = """\
<h1 style="text-align: center;">Hi, I'm Gemma 2 (2B) 👋</h1>
This is a demo of <strong>google/gemma-2-2b-it</strong>, the instruction-tuned 2B variant of Gemma 2.
For more details, see
<a href="https://huggingface.co/blog/gemma2" target="_blank">the Gemma 2 blog post</a>.
👉 Looking for a larger version? Try the 27B in
<a href="https://huggingface.co/chat/models/google/gemma-2-27b-it" target="_blank">HuggingChat</a>
and the 9B in
<a href="https://huggingface.co/spaces/huggingface-projects/gemma-2-9b-it" target="_blank">this Space</a>.
"""
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
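# The input budget can be overridden at launch time via the environment, e.g.:
#   MAX_INPUT_TOKEN_LENGTH=8192 python app.py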
model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
)
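# Gemma 2 interleaves sliding-window and global attention layers; setting the
# window explicitly to 4096 (the model's default) keeps it aligned with the
# default MAX_INPUT_TOKEN_LENGTH above.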
model.config.sliding_window = 4096
model.eval()
############################################################
# Generator function (streaming approach)
############################################################
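# On ZeroGPU Spaces, @spaces.GPU allocates a GPU for the duration of each call;
# on regular GPU hardware the decorator is a no-op.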
@spaces.GPU
def generate(
message: str,
chat_history: list[dict],
max_new_tokens: int = 1024,
temperature: float = 0.6,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.2,
) -> Iterator[str]:
"""Generate text from the model and stream tokens back to the UI."""
conversation = chat_history.copy()
conversation.append({"role": "user", "content": message})
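    # Render the conversation with Gemma's chat template and tokenize it in one step.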
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Input trimmed to the last {MAX_INPUT_TOKEN_LENGTH} tokens because the conversation was too long.")
input_ids = input_ids.to(model.device)
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
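    # If no new token arrives within the 20 s timeout, the streamer's iterator
    # raises queue.Empty; this is handled at the end of the function.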
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
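    # Run generation in a background thread so this function can iterate over
    # the streamer and yield partial output to the UI as tokens arrive.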
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
try:
for text in streamer:
outputs.append(text)
yield "".join(outputs)
    except queue.Empty:
        # The streamer timed out waiting for the next token; end the stream
        # gracefully instead of surfacing a traceback.
        return
############################################################
# CREATE_INTERFACE function returning a gr.Blocks
############################################################
def create_interface() -> gr.Blocks:
"""
Build a custom Blocks interface containing:
- A Chatbot with user/bot icons
- A ChatInterface that uses the chatbot
- Custom example suggestions with special styling
"""
gemma_css = """
:root {
--gradient-start: #66AEEF; /* lighter top */
--gradient-end: #F0F8FF; /* very light at bottom */
}
/* Overall page & container background gradient */
html, body, .gradio-container {
margin: 0;
padding: 0;
background: linear-gradient(to bottom, var(--gradient-start), var(--gradient-end));
font-family: "Helvetica", sans-serif;
color: #333; /* dark gray for better contrast */
}
/* Make anchor (link) text a clearly visible dark blue */
a, a:visited {
color: #02497A !important;
text-decoration: underline;
}
/* Center the top headings in the description */
.gradio-container h1 {
margin-top: 0.8em;
margin-bottom: 0.5em;
text-align: center;
color: #fff; /* White text on top gradient for pop */
}
/* Chat container background: a very light blue so it's distinct from the outer gradient */
.chatbot, .chatbot .wrap, .chat-interface, .chat-interface .wrap {
background-color: #F8FDFF !important;
}
/* Remove harsh frames around chat messages */
.chatbot .message, .chat-message {
border: none !important;
position: relative;
}
/* Icons for user and bot messages (Chatbot) */
.chatbot .user .chat-avatar {
background: url('user.png') center center no-repeat;
background-size: cover;
}
.chatbot .bot .chat-avatar {
background: url('gemma.png') center center no-repeat;
background-size: cover;
}
/* Icons for user and bot messages (ChatInterface) */
.chat-message.user::before {
content: '';
display: inline-block;
background: url('user.png') center center no-repeat;
background-size: cover;
width: 24px;
height: 24px;
margin-right: 8px;
vertical-align: middle;
}
.chat-message.bot::before {
content: '';
display: inline-block;
background: url('gemma.png') center center no-repeat;
background-size: cover;
width: 24px;
height: 24px;
margin-right: 8px;
vertical-align: middle;
}
/* Chat bubbles (ChatInterface) */
.chat-message.user {
background-color: #0284C7 !important;
color: #FFFFFF !important;
border-radius: 8px;
padding: 8px 12px;
margin: 6px 0;
}
.chat-message.bot {
background-color: #EFF8FF !important;
color: #333 !important;
border-radius: 8px;
padding: 8px 12px;
margin: 6px 0;
}
/* Chat input area */
.chat-input textarea {
background-color: #FFFFFF;
color: #333;
border: 1px solid #66AEEF;
border-radius: 6px;
padding: 8px;
}
/* Sliders & other controls */
form.sliders input[type="range"] {
accent-color: #66AEEF;
}
form.sliders label {
color: #333;
}
.gradio-button, .chat-send-btn {
background-color: #0284C7 !important;
color: #FFFFFF !important;
border-radius: 5px;
border: none;
cursor: pointer;
}
.gradio-button:hover, .chat-send-btn:hover {
background-color: #026FA6 !important;
}
/* Style the example "pill" buttons (ChatInterface) */
.gr-examples {
display: flex !important;
flex-wrap: wrap;
gap: 16px;
justify-content: center;
margin-bottom: 1em !important;
}
.gr-examples button.example {
background-color: #EFF8FF !important;
border: 1px solid #66AEEF !important;
border-radius: 8px !important;
color: #333 !important;
padding: 10px 16px !important;
cursor: pointer !important;
transition: background-color 0.2s !important;
}
.gr-examples button.example:hover {
background-color: #E0F2FF !important;
}
/* Additional spacing / small tweaks */
#duplicate-button {
margin: auto;
background: #1565c0;
border-radius: 100vh;
color: #fff;
}
"""
with gr.Blocks(css=gemma_css) as app:
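        # Note: the CSS above targets Gradio's internal class names, which can
        # change between releases; revisit it if styling breaks after an upgrade.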
# A heading or custom markdown
gr.Markdown(DESCRIPTION)
        # Define a custom Chatbot so we can control avatars, height, etc.
        # type="messages" makes the history a list of {"role", "content"} dicts,
        # which is what generate() expects.
        chatbot = gr.Chatbot(
            label="Gemma Chat (Blocks-based)",
            type="messages",
            avatar_images=("user.png", "gemma.png"),
            height=450,
            show_copy_button=True,
        )
        # Wire a ChatInterface to the generate function, reusing the Chatbot
        # defined above. The custom CSS is applied once on the outer Blocks,
        # so it is not repeated here.
        interface = gr.ChatInterface(
            fn=generate,
            chatbot=chatbot,  # link the Chatbot to the ChatInterface
            description="Gemma 2",
additional_inputs=[
gr.Slider(
label="Max new tokens",
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
value=DEFAULT_MAX_NEW_TOKENS,
),
gr.Slider(
label="Temperature",
minimum=0.1,
maximum=4.0,
step=0.1,
value=0.6,
),
gr.Slider(
label="Top-p (nucleus sampling)",
minimum=0.05,
maximum=1.0,
step=0.05,
value=0.9,
),
gr.Slider(
label="Top-k",
minimum=1,
maximum=1000,
step=1,
value=50,
),
gr.Slider(
label="Repetition penalty",
minimum=1.0,
maximum=2.0,
step=0.05,
value=1.2,
),
],
examples=[
["Hello there! How are you doing?"],
                ["Can you briefly explain what the Python programming language is?"],
["Explain the plot of Cinderella in a sentence."],
["How many hours does it take a man to eat a Helicopter?"],
["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
],
cache_examples=False,
fill_height=True,
)
return app
############################################################
# Main script entry
############################################################
def main():
demo = create_interface()
# Launch the app with queue for concurrency/streaming
demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860, debug=True)
if __name__ == "__main__":
main()