// SmolLM3-3B-WebGPU / src/worker.js
// Web Worker that runs SmolLM3-3B in the browser with Transformers.js on WebGPU.
import {
AutoTokenizer,
AutoModelForCausalLM,
TextStreamer,
InterruptableStoppingCriteria,
} from "@huggingface/transformers";
/**
* Helper function to perform feature detection for WebGPU
*/
async function check() {
try {
const adapter = await navigator.gpu.requestAdapter();
if (!adapter) {
throw new Error("WebGPU is not supported (no adapter found)");
}
if (!adapter.features.has("shader-f16")) {
throw new Error("shader-f16 is not supported in this browser");
}
} catch (e) {
self.postMessage({
status: "error",
data: e.toString(),
});
}
}
/**
* This class uses the Singleton pattern to enable lazy-loading of the pipeline
*/
class TextGenerationPipeline {
static model_id = "HuggingFaceTB/SmolLM3-3B-ONNX";
static async getInstance(progress_callback = null) {
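    // Memoize the loading promises so repeated calls reuse the same downloads.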
this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
progress_callback,
});
this.model ??= AutoModelForCausalLM.from_pretrained(this.model_id, {
dtype: "q4f16",
device: "webgpu",
progress_callback,
});
return Promise.all([this.tokenizer, this.model]);
}
}
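// Shared stopping criteria so the main thread can interrupt an in-progress generation.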
const stopping_criteria = new InterruptableStoppingCriteria();
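// Attention key/value cache carried across turns so earlier conversation tokens are not re-processed.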
let past_key_values_cache = null;
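// Run one round of chat generation and stream results back to the main thread.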
async function generate({ messages, reasonEnabled }) {
const [tokenizer, model] = await TextGenerationPipeline.getInstance();
const inputs = tokenizer.apply_chat_template(messages, {
enable_thinking: reasonEnabled,
add_generation_prompt: true,
return_dict: true,
});
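  // Token ids that delimit the model's reasoning ("thinking") section of the output.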
const [START_THINKING_TOKEN_ID, END_THINKING_TOKEN_ID] = tokenizer.encode(
"<think></think>",
{ add_special_tokens: false },
);
let state = "answering"; // 'thinking' or 'answering'
let startTime;
let numTokens = 0;
let tps;
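  // Called once per generated token: tracks tokens per second (measured from the
  // first generated token) and whether decoding is inside the <think> block.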
const token_callback_function = (tokens) => {
startTime ??= performance.now();
if (numTokens++ > 0) {
tps = (numTokens / (performance.now() - startTime)) * 1000;
}
switch (Number(tokens[0])) {
case START_THINKING_TOKEN_ID:
state = "thinking";
break;
case END_THINKING_TOKEN_ID:
state = "answering";
break;
}
};
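  // Called with each newly decoded text chunk; forwards partial output and stats to the main thread.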
const callback_function = (output) => {
self.postMessage({
status: "update",
output,
tps,
numTokens,
state,
});
};
const streamer = new TextStreamer(tokenizer, {
skip_prompt: true,
skip_special_tokens: true,
callback_function,
token_callback_function,
});
// Tell the main thread we are starting
self.postMessage({ status: "start" });
const { past_key_values, sequences } = await model.generate({
...inputs,
past_key_values: past_key_values_cache,
    // Decoding options: greedy with a mild repetition penalty and a larger token
    // budget when reasoning is enabled; top-k sampling otherwise.
do_sample: !reasonEnabled,
repetition_penalty: reasonEnabled ? 1.1 : undefined,
top_k: 3,
max_new_tokens: reasonEnabled ? 4096 : 1024,
streamer,
stopping_criteria,
return_dict_in_generate: true,
});
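  // Keep the returned KV cache so the next turn only processes the new tokens.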
past_key_values_cache = past_key_values;
const decoded = tokenizer.batch_decode(sequences, {
skip_special_tokens: true,
});
// Send the output back to the main thread
self.postMessage({
status: "complete",
output: decoded,
});
}
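// Download the tokenizer and model, warm up the WebGPU backend, and report progress to the main thread.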
async function load() {
self.postMessage({
status: "loading",
data: "Loading model...",
});
// Load the pipeline and save it for future use.
const [tokenizer, model] = await TextGenerationPipeline.getInstance((x) => {
// We also add a progress callback to the pipeline so that we can
// track model loading.
self.postMessage(x);
});
self.postMessage({
status: "loading",
data: "Compiling shaders and warming up model...",
});
// Run model with dummy input to compile shaders
const inputs = tokenizer("a");
await model.generate({ ...inputs, max_new_tokens: 1 });
self.postMessage({ status: "ready" });
}
// Listen for messages from the main thread
self.addEventListener("message", async (e) => {
const { type, data } = e.data;
switch (type) {
case "check":
check();
break;
case "load":
load();
break;
case "generate":
stopping_criteria.reset();
generate(data);
break;
case "interrupt":
stopping_criteria.interrupt();
break;
case "reset":
past_key_values_cache = null;
stopping_criteria.reset();
break;
}
});
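// Illustrative sketch of how the main thread might drive this worker. The worker
// URL and bundler setup below are assumptions; the message shapes match the
// handler above.
//
//   const worker = new Worker(new URL("./worker.js", import.meta.url), { type: "module" });
//   worker.addEventListener("message", (e) => console.log(e.data));
//   worker.postMessage({ type: "check" });
//   worker.postMessage({ type: "load" });
//   worker.postMessage({
//     type: "generate",
//     data: { messages: [{ role: "user", content: "Hi!" }], reasonEnabled: false },
//   });
//   worker.postMessage({ type: "interrupt" }); // stop generation early
//   worker.postMessage({ type: "reset" });     // clear the cached KV state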