import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "HuggingFaceTB/SmolLM3-3B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading tokenizer & model…")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16
).to(DEVICE)
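# An alternative load path (a sketch, assuming the `accelerate` package is
# installed) lets transformers place the weights itself instead of the manual
# .to(DEVICE) above:
#
#     model = AutoModelForCausalLM.from_pretrained(
#         MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto"
#     )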
# -------------------------------------------------
# Optional tool(s)
# -------------------------------------------------
# TOOLS = [{
#     "name": "get_weather",
#     "description": "Get the current weather in a given city",
#     "parameters": {
#         "type": "object",
#         "properties": {
#             "city": {"type": "string", "description": "City name"}
#         },
#         "required": ["city"]
#     }
# }]
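# To enable tool calling, uncomment TOOLS above and pass it to the chat
# template via the commented-out `xml_tools=TOOLS` argument in chat_fn below;
# that is the parameter this script's template call already anticipates.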
# -------------------------------------------------
# Helpers
# -------------------------------------------------
def build_messages(history, enable_thinking: bool):
    """Convert Gradio message history to the chat-template message format."""
    messages = []
    for h in history:
        messages.append({"role": h["role"], "content": h["content"]})
    # SmolLM3 toggles extended thinking via a /think or /no_think system flag.
    system_flag = "/think" if enable_thinking else "/no_think"
    messages.insert(0, {"role": "system", "content": system_flag})
    return messages
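# Worked example (hypothetical input): with thinking disabled,
#     build_messages([{"role": "user", "content": "Hi"}], False)
# returns
#     [{"role": "system", "content": "/no_think"},
#      {"role": "user", "content": "Hi"}]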
def chat_fn(history, enable_thinking, temperature, top_p, top_k, repetition_penalty):
    """Generate a full response, then stream it to the UI character by character."""
    messages = build_messages(history, enable_thinking)
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        # xml_tools=TOOLS
    )
    inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
    output = model.generate(
        **inputs,
        max_new_tokens=1024,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Drop the prompt tokens; keep only the newly generated ones.
    output_ids = output[0][len(inputs.input_ids[0]):]
    response = tokenizer.decode(output_ids, skip_special_tokens=True)
    # Pseudo-streaming: the response is already complete; yield it character
    # by character so the Chatbot updates progressively.
    history.append({"role": "assistant", "content": ""})
    for ch in response:
        history[-1]["content"] += ch
        yield history
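# For true token-level streaming, transformers' TextIteratorStreamer can
# replace the generate-then-replay loop in chat_fn. A minimal sketch, not
# wired into the UI below; the name chat_fn_streaming is an assumption:
from threading import Thread
from transformers import TextIteratorStreamer

def chat_fn_streaming(history, enable_thinking, temperature, top_p, top_k, repetition_penalty):
    """Stream tokens as they are generated instead of replaying a finished response."""
    messages = build_messages(history, enable_thinking)
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **inputs, max_new_tokens=1024, do_sample=True, temperature=temperature,
        top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.eos_token_id, streamer=streamer,
    )
    # Run generation in a background thread; the streamer yields decoded
    # text chunks as they are produced.
    Thread(target=model.generate, kwargs=gen_kwargs).start()
    history.append({"role": "assistant", "content": ""})
    for piece in streamer:
        history[-1]["content"] += piece
        yield history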
# -------------------------------------------------
# Blocks UI
# -------------------------------------------------
with gr.Blocks(title="SmolLM3-3B Chat") as demo:
    gr.Markdown("## 🤖 SmolLM3-3B Chatbot (Streaming)")
    with gr.Row():
        enable_think = gr.Checkbox(label="Enable Extended Thinking (/think)", value=False)
        # Min 0.1: temperature=0 is invalid when do_sample=True.
        temperature = gr.Slider(0.1, 1.0, value=0.6, label="Temperature")
        top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top-p")
        top_k = gr.Slider(1, 40, value=20, label="Top-k")
        repetition_penalty = gr.Slider(1.0, 1.4, value=1.1, label="Repetition penalty")
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox(placeholder="Type your message here…", lines=1)
    clear = gr.Button("Clear")
    def user_fn(user_msg, history):
        """Append the user's message to the history and clear the input box."""
        return "", history + [{"role": "user", "content": user_msg}]

    # Two-step chain: echo the user's message immediately (unqueued), then
    # stream the assistant's reply into the same Chatbot.
    msg.submit(
        user_fn, [msg, chatbot], [msg, chatbot], queue=False
    ).then(
        chat_fn, [chatbot, enable_think, temperature, top_p, top_k, repetition_penalty], chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue().launch()
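# Usage: on a Hugging Face Space this file is launched automatically; to run
# it locally, install gradio, torch, and transformers, then:
#     python app.py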