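# Hugging Face Spaces page header captured with this file ("Spaces: Sleeping");
# kept as a comment so it is not parsed as Python.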
import os | |
import queue | |
from collections.abc import Iterator | |
from threading import Thread | |
import gradio as gr | |
import spaces | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
############################################################
# Model setup (modify as needed)
############################################################
# HTML/markdown banner rendered at the top of the app.
DESCRIPTION = """\
<h1 style="text-align: center;">Hi, I'm Gemma 2 (2B) 👋</h1>
This is a demo of <strong>google/gemma-2-2b-it</strong> fine-tuned for instruction following.
For more details, please check
<a href="https://huggingface.co/blog/gemma2" target="_blank">the post</a>.
👉 Looking for a larger version? Try the 27B in
<a href="https://huggingface.co/chat/models/google/gemma-2-27b-it" target="_blank">HuggingChat</a>
and the 9B in
<a href="https://huggingface.co/spaces/huggingface-projects/gemma-2-9b-it" target="_blank">this Space</a>.
"""

# Generation limits:
#   MAX_MAX_NEW_TOKENS     - upper bound of the "Max new tokens" slider
#   DEFAULT_MAX_NEW_TOKENS - initial value of that slider
#   MAX_INPUT_TOKEN_LENGTH - prompt-token budget; longer prompts are trimmed.
#                            Overridable via the environment variable of the same name.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

# First CUDA device when available, otherwise CPU. Weight placement itself is
# left to device_map="auto" below; generation moves inputs to model.device.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Explicitly pin the config's sliding-window size to 4096 tokens
# (NOTE(review): presumably to match Gemma-2's attention window — confirm).
model.config.sliding_window = 4096
# Inference only: switch off training-mode behaviour such as dropout.
model.eval()
############################################################
# Generator function (streaming approach)
############################################################
def _history_to_messages(chat_history) -> list[dict]:
    """Return *chat_history* as a new list of ``{"role", "content"}`` message dicts.

    Gradio supplies history either in "messages" format (dicts carrying
    ``role``/``content``) or, when the chat components use the legacy
    "tuples" format, as ``[user_text, bot_text]`` pairs. The tokenizer's chat
    template only accepts message dicts, so pairs are expanded here and
    ``None`` halves of a pair are skipped. ``None``/empty history yields ``[]``.
    """
    messages: list[dict] = []
    for entry in chat_history or []:
        if isinstance(entry, dict):
            messages.append(entry)
            continue
        user_text, bot_text = entry
        if user_text is not None:
            messages.append({"role": "user", "content": user_text})
        if bot_text is not None:
            messages.append({"role": "assistant", "content": bot_text})
    return messages


def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,  # keep in sync with DEFAULT_MAX_NEW_TOKENS
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Generate a reply to *message* and stream it back to the UI.

    Args:
        message: The new user turn.
        chat_history: Previous turns, as message dicts or ``[user, bot]`` pairs.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens considered when sampling.
        repetition_penalty: Penalty applied to already-generated tokens.

    Yields:
        The cumulative reply text after each streamed chunk.
    """
    conversation = _history_to_messages(chat_history)
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens so the latest user turn survives trimming.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # skip_prompt: stream only newly generated text. timeout: the maximum wait
    # (seconds) for each next chunk before the streamer raises queue.Empty.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        # Slider values may arrive as floats; generation expects integers here.
        "max_new_tokens": int(max_new_tokens),
        "do_sample": True,
        "top_p": top_p,
        "top_k": int(top_k),
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    # model.generate blocks until finished, so run it in a worker thread while
    # this generator drains the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs: list[str] = []
    try:
        for text in streamer:
            outputs.append(text)
            yield "".join(outputs)
    except queue.Empty:
        # Normal completion ends the iteration cleanly; queue.Empty means no new
        # chunk arrived within the timeout. Tell the user instead of silently
        # returning a truncated (or empty) reply.
        gr.Warning("Generation timed out waiting for the model; the reply may be incomplete.")
        return
############################################################
# CREATE_INTERFACE function returning a gr.Blocks
############################################################
def create_interface() -> gr.Blocks:
    """
    Build the Gradio Blocks app for chatting with the model.

    Layout: the DESCRIPTION markdown, followed by a ChatInterface wired to
    ``generate``. The ChatInterface embeds a customised ``gr.Chatbot`` (user
    and Gemma avatars, copy button), sliders for the sampling parameters, and
    a set of example prompts. ``gemma_css`` styles the whole page.

    Returns:
        The ``gr.Blocks`` app (not yet queued or launched).
    """
    # NOTE(review): the url('user.png') / url('gemma.png') rules below are
    # resolved by the browser relative to the page URL, so they only load if
    # those files are served at that path. The Chatbot's avatar_images argument
    # also sets the avatars — confirm whether these CSS rules are still needed.
    gemma_css = """
    :root {
        --gradient-start: #66AEEF; /* lighter top */
        --gradient-end: #F0F8FF;   /* very light at bottom */
    }
    /* Overall page & container background gradient */
    html, body, .gradio-container {
        margin: 0;
        padding: 0;
        background: linear-gradient(to bottom, var(--gradient-start), var(--gradient-end));
        font-family: "Helvetica", sans-serif;
        color: #333; /* dark gray for better contrast */
    }
    /* Make anchor (link) text a clearly visible dark blue */
    a, a:visited {
        color: #02497A !important;
        text-decoration: underline;
    }
    /* Center the top headings in the description */
    .gradio-container h1 {
        margin-top: 0.8em;
        margin-bottom: 0.5em;
        text-align: center;
        color: #fff; /* White text on top gradient for pop */
    }
    /* Chat container background: a very light blue so it's distinct from the outer gradient */
    .chatbot, .chatbot .wrap, .chat-interface, .chat-interface .wrap {
        background-color: #F8FDFF !important;
    }
    /* Remove harsh frames around chat messages */
    .chatbot .message, .chat-message {
        border: none !important;
        position: relative;
    }
    /* Icons for user and bot messages (Chatbot) */
    .chatbot .user .chat-avatar {
        background: url('user.png') center center no-repeat;
        background-size: cover;
    }
    .chatbot .bot .chat-avatar {
        background: url('gemma.png') center center no-repeat;
        background-size: cover;
    }
    /* Icons for user and bot messages (ChatInterface) */
    .chat-message.user::before {
        content: '';
        display: inline-block;
        background: url('user.png') center center no-repeat;
        background-size: cover;
        width: 24px;
        height: 24px;
        margin-right: 8px;
        vertical-align: middle;
    }
    .chat-message.bot::before {
        content: '';
        display: inline-block;
        background: url('gemma.png') center center no-repeat;
        background-size: cover;
        width: 24px;
        height: 24px;
        margin-right: 8px;
        vertical-align: middle;
    }
    /* Chat bubbles (ChatInterface) */
    .chat-message.user {
        background-color: #0284C7 !important;
        color: #FFFFFF !important;
        border-radius: 8px;
        padding: 8px 12px;
        margin: 6px 0;
    }
    .chat-message.bot {
        background-color: #EFF8FF !important;
        color: #333 !important;
        border-radius: 8px;
        padding: 8px 12px;
        margin: 6px 0;
    }
    /* Chat input area */
    .chat-input textarea {
        background-color: #FFFFFF;
        color: #333;
        border: 1px solid #66AEEF;
        border-radius: 6px;
        padding: 8px;
    }
    /* Sliders & other controls */
    form.sliders input[type="range"] {
        accent-color: #66AEEF;
    }
    form.sliders label {
        color: #333;
    }
    .gradio-button, .chat-send-btn {
        background-color: #0284C7 !important;
        color: #FFFFFF !important;
        border-radius: 5px;
        border: none;
        cursor: pointer;
    }
    .gradio-button:hover, .chat-send-btn:hover {
        background-color: #026FA6 !important;
    }
    /* Style the example "pill" buttons (ChatInterface) */
    .gr-examples {
        display: flex !important;
        flex-wrap: wrap;
        gap: 16px;
        justify-content: center;
        margin-bottom: 1em !important;
    }
    .gr-examples button.example {
        background-color: #EFF8FF !important;
        border: 1px solid #66AEEF !important;
        border-radius: 8px !important;
        color: #333 !important;
        padding: 10px 16px !important;
        cursor: pointer !important;
        transition: background-color 0.2s !important;
    }
    .gr-examples button.example:hover {
        background-color: #E0F2FF !important;
    }
    /* Additional spacing / small tweaks */
    #duplicate-button {
        margin: auto;
        background: #1565c0;
        border-radius: 100vh;
        color: #fff;
    }
    """
    with gr.Blocks(css=gemma_css) as app:
        gr.Markdown(DESCRIPTION)

        # render=False: the ChatInterface below renders this component inside its
        # own layout. Rendering it here as well would place it outside the chat UI,
        # and some Gradio versions reject rendering the same block twice.
        # type="messages": generate() consumes {"role", "content"} message dicts.
        chatbot = gr.Chatbot(
            label="Gemma Chat (Blocks-based)",
            avatar_images=("user.png", "gemma.png"),
            height=450,
            show_copy_button=True,
            type="messages",
            render=False,
        )

        # Instantiated inside the Blocks context so it is laid out as part of
        # `app`; the page stylesheet comes from the enclosing gr.Blocks(css=...).
        gr.ChatInterface(
            fn=generate,
            chatbot=chatbot,
            type="messages",
            description="Gemma 2",
            additional_inputs=[
                gr.Slider(
                    label="Max new tokens",
                    minimum=1,
                    maximum=MAX_MAX_NEW_TOKENS,
                    step=1,
                    value=DEFAULT_MAX_NEW_TOKENS,
                ),
                gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=4.0,
                    step=0.1,
                    value=0.6,
                ),
                gr.Slider(
                    label="Top-p (nucleus sampling)",
                    minimum=0.05,
                    maximum=1.0,
                    step=0.05,
                    value=0.9,
                ),
                gr.Slider(
                    label="Top-k",
                    minimum=1,
                    maximum=1000,
                    step=1,
                    value=50,
                ),
                gr.Slider(
                    label="Repetition penalty",
                    minimum=1.0,
                    maximum=2.0,
                    step=0.05,
                    value=1.2,
                ),
            ],
            examples=[
                ["Hello there! How are you doing?"],
                ["Can you explain briefly to me what is the Python programming language?"],
                ["Explain the plot of Cinderella in a sentence."],
                ["How many hours does it take a man to eat a Helicopter?"],
                ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
            ],
            cache_examples=False,
            fill_height=True,
        )
    return app
############################################################
# Main script entry
############################################################
def main():
    """Build the chat app and serve it on every interface at port 7860."""
    app = create_interface()
    # The queue (at most 20 pending requests) is what lets the streaming
    # generator push partial replies to the browser.
    queued_app = app.queue(max_size=20)
    queued_app.launch(debug=True, server_port=7860, server_name="0.0.0.0")


if __name__ == "__main__":
    main()