File size: 4,525 Bytes
b1558e3
 
 
2013b5a
15bd5c0
b1558e3
 
1ceaf3e
b1558e3
 
1ceaf3e
 
b1558e3
 
 
bd1d5de
b1558e3
 
1ceaf3e
15bd5c0
b1558e3
 
 
 
1ceaf3e
b1558e3
 
 
1ceaf3e
 
 
15bd5c0
 
 
 
fb4e39c
 
 
 
 
 
 
 
 
 
 
15bd5c0
 
 
 
000548f
15bd5c0
 
 
 
 
 
 
 
 
 
fb8294c
15bd5c0
 
 
 
 
 
fb4e39c
15bd5c0
 
 
 
 
fb8294c
15bd5c0
 
 
fb4e39c
 
15bd5c0
 
 
 
 
20ff6a1
 
 
 
 
 
15bd5c0
 
 
 
 
 
 
 
 
 
 
 
 
50a8e40
15bd5c0
 
fb4e39c
 
fb8294c
15bd5c0
 
 
 
 
 
 
 
 
 
000548f
15bd5c0
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import bitsandbytes

MODEL_ID = "HuggingFaceTB/SmolLM3-3B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the tokenizer and model once at import time; every chat request reuses them.
print("Loading tokenizer & model…")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to(DEVICE)

if DEVICE == "cuda":
    # Quantization options go through ``quantization_config``; passing
    # ``load_in_8bit=True`` straight to ``from_pretrained`` is deprecated.
    # Use ``BitsAndBytesConfig(load_in_4bit=True)`` for a smaller footprint.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map=DEVICE,
    )
else:
    # bitsandbytes 8-bit loading has historically required a CUDA GPU
    # (transformers raises "No GPU found" otherwise), so on a CPU-only host
    # load the unquantized model instead of failing at startup.
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(DEVICE)

#########

# print("Loading tokenizer & model…")
# import gradio as gr
# from transformers import AutoTokenizer
# from optimum.onnxruntime import ORTModelForCausalLM

# MODEL_ID = "HuggingFaceTB/SmolLM3-3B"
# tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# model = ORTModelForCausalLM.from_pretrained(MODEL_ID, export=True, quantize=True)

#########


# -------------------------------------------------
# Optional tool(s)
# -------------------------------------------------
# TOOLS = [{
#     "name": "get_weather",
#     "description": "Get the current weather in a given city",
#     "parameters": {
#         "type": "object",
#         "properties": {
#             "city": {"type": "string", "description": "City name"}
#         },
#         "required": ["city"]
#     }
# }]

# -------------------------------------------------
# Helpers
# -------------------------------------------------

def build_messages(history, enable_thinking: bool):
    """Turn Gradio chat history into a chat-template message list.

    A system message carrying the SmolLM3 mode switch ("/think" or
    "/no_think") is placed first, followed by a copy of every history turn
    reduced to just its "role" and "content" keys. ``history`` itself is
    not modified.
    """
    mode_switch = "/no_think"
    if enable_thinking:
        mode_switch = "/think"
    turns = [{"role": turn["role"], "content": turn["content"]} for turn in history]
    return [{"role": "system", "content": mode_switch}, *turns]

def _escape_think_tags(text):
    """Return ``text`` with ``<think>``/``</think>`` tags rewritten as HTML entities.

    The opening tag is also prefixed with ``# `` so it shows as a heading.
    Presumably the escaping keeps the chatbot renderer from treating the
    tags as HTML — NOTE(review): confirm intent.
    """
    return text.replace("<think>", "# &lt;think&gt;").replace("</think>", "&lt;/think&gt;")


def chat_fn(history, enable_thinking, temperature, top_p, top_k, repetition_penalty, max_new_tokens):
    """Stream the assistant's reply into ``history`` while it is being generated.

    Args:
        history: Gradio "messages"-style list of ``{"role", "content"}`` dicts.
            A new assistant turn is appended to it in place.
        enable_thinking: whether the system prompt requests ``/think`` mode.
        temperature: sampling temperature. Values <= 0 switch to greedy
            decoding, because ``generate`` rejects ``temperature=0`` when
            sampling.
        top_p, top_k, repetition_penalty: decoding controls passed to
            ``generate``. ``top_k`` is cast to ``int`` because slider values
            may be floats and transformers requires an integer ``top_k``.
        max_new_tokens: upper bound on generated tokens (cast to ``int``).

    Yields:
        ``history`` after each decoded text chunk, with think tags escaped.

    Raises:
        Any exception raised by ``model.generate`` in the worker thread,
        re-raised once the stream has ended.
    """
    from threading import Event, Thread

    from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer

    messages = build_messages(history, enable_thinking)
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        # xml_tools=TOOLS
    )
    inputs = tokenizer(text, return_tensors="pt").to(DEVICE)

    # skip_prompt drops the echoed prompt tokens, so only the new reply is streamed.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    stop_requested = Event()

    class _StopWhenRequested(StoppingCriteria):
        """Halt generation once the stream consumer has gone away (e.g. cancelled)."""

        def __call__(self, input_ids, scores, **kwargs):
            return torch.full(
                (input_ids.shape[0],),
                stop_requested.is_set(),
                dtype=torch.bool,
                device=input_ids.device,
            )

    gen_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList([_StopWhenRequested()]),
    )
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p, top_k=int(top_k))
    else:
        gen_kwargs["do_sample"] = False

    errors = []

    def _generate():
        # Runs in a worker thread; the streamer hands text back to this generator.
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:  # re-raised in the caller after the stream ends
            errors.append(exc)
            streamer.end()  # unblock the consumer loop below

    worker = Thread(target=_generate, daemon=True)
    worker.start()

    history.append({"role": "assistant", "content": ""})
    raw = ""
    try:
        for chunk in streamer:
            raw += chunk
            # Escape the full accumulated text so tags split across chunks are still caught.
            history[-1]["content"] = _escape_think_tags(raw)
            yield history
    finally:
        # If the UI closes this generator early, stop the background generation too.
        stop_requested.set()
    worker.join()
    if errors:
        raise errors[0]

# -------------------------------------------------
# Blocks UI
# -------------------------------------------------
with gr.Blocks(title="SmolLM3-3B Chat") as demo:
    gr.Markdown("## 🤖 SmolLM3-3B Chatbot (Streaming)")
    with gr.Row():
        enable_think = gr.Checkbox(label="Enable Extended Thinking (/think)", value=True)
        temperature = gr.Slider(0.0, 1.0, value=0.6, label="Temperature")
        top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top-p")
        # step=1: without it Gradio may derive a fractional step, and
        # transformers rejects a non-integer top_k.
        top_k = gr.Slider(1, 40, value=20, step=1, label="Top_k")
        repetition_penalty = gr.Slider(1.0, 1.4, value=1.1, label="Repetition_Penalty")
        max_new_tokens = gr.Slider(1000, 32768, value=32768, label="Max_New_Tokens")
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox(placeholder="Type your message here…", lines=1)
    clear = gr.Button("Clear")

    def user_fn(user_msg, history):
        """Empty the textbox and append the user's turn to the chat history."""
        # history may be None after "Clear" resets the chatbot to None.
        return "", (history or []) + [{"role": "user", "content": user_msg}]

    generation = msg.submit(
        user_fn, [msg, chatbot], [msg, chatbot], queue=False
    ).then(
        chat_fn, [chatbot, enable_think, temperature, top_p, top_k, repetition_penalty, max_new_tokens], chatbot
    )
    # Also cancel an in-flight reply; otherwise its next streamed update would
    # repaint the old conversation right after it was cleared.
    clear.click(lambda: None, None, chatbot, queue=False, cancels=[generation])

demo.queue().launch()