import {
  AutoTokenizer,
  AutoModelForCausalLM,
  TextStreamer,
  InterruptableStoppingCriteria,
} from "@huggingface/transformers";

/**
 * Helper function to perform feature detection for WebGPU
 */
async function check() {
  try {
    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) {
      throw new Error("WebGPU is not supported (no adapter found)");
    }
    if (!adapter.features.has("shader-f16")) {
      throw new Error("shader-f16 is not supported in this browser");
    }
  } catch (e) {
    self.postMessage({
      status: "error",
      data: e.toString(),
    });
  }
}

/**
 * This class uses the Singleton pattern to enable lazy-loading of the pipeline
 */
class TextGenerationPipeline {
  static model_id = "HuggingFaceTB/SmolLM3-3B-ONNX";

  static async getInstance(progress_callback = null) {
    this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
      progress_callback,
    });

    this.model ??= AutoModelForCausalLM.from_pretrained(this.model_id, {
      dtype: "q4f16",
      device: "webgpu",
      progress_callback,
    });

    return Promise.all([this.tokenizer, this.model]);
  }
}

const stopping_criteria = new InterruptableStoppingCriteria();

// KV cache reused across generations of the same conversation
let past_key_values_cache = null;

async function generate({ messages, reasonEnabled }) {
  const [tokenizer, model] = await TextGenerationPipeline.getInstance();

  const inputs = tokenizer.apply_chat_template(messages, {
    enable_thinking: reasonEnabled,
    add_generation_prompt: true,
    return_dict: true,
  });

  // Token ids that mark the start and end of the model's reasoning block
  const [START_THINKING_TOKEN_ID, END_THINKING_TOKEN_ID] = tokenizer.encode(
    "<think></think>",
    { add_special_tokens: false },
  );

  let state = "answering"; // 'thinking' or 'answering'
  let startTime;
  let numTokens = 0;
  let tps;

  // Track tokens per second and whether the model is currently reasoning
  const token_callback_function = (tokens) => {
    startTime ??= performance.now();
    if (numTokens++ > 0) {
      tps = (numTokens / (performance.now() - startTime)) * 1000;
    }
    switch (Number(tokens[0])) {
      case START_THINKING_TOKEN_ID:
        state = "thinking";
        break;
      case END_THINKING_TOKEN_ID:
        state = "answering";
        break;
    }
  };

  // Stream partial output back to the main thread
  const callback_function = (output) => {
    self.postMessage({
      status: "update",
      output,
      tps,
      numTokens,
      state,
    });
  };

  const streamer = new TextStreamer(tokenizer, {
    skip_prompt: true,
    skip_special_tokens: true,
    callback_function,
    token_callback_function,
  });

  // Tell the main thread we are starting
  self.postMessage({ status: "start" });

  const { past_key_values, sequences } = await model.generate({
    ...inputs,
    past_key_values: past_key_values_cache,

    // Sampling
    do_sample: !reasonEnabled,
    repetition_penalty: reasonEnabled ? 1.1 : undefined,
    top_k: 3,

    max_new_tokens: reasonEnabled ? 4096 : 1024,
    streamer,
    stopping_criteria,
    return_dict_in_generate: true,
  });
  past_key_values_cache = past_key_values;

  const decoded = tokenizer.batch_decode(sequences, {
    skip_special_tokens: true,
  });

  // Send the output back to the main thread
  self.postMessage({
    status: "complete",
    output: decoded,
  });
}

async function load() {
  self.postMessage({
    status: "loading",
    data: "Loading model...",
  });

  // Load the pipeline and save it for future use.
  const [tokenizer, model] = await TextGenerationPipeline.getInstance((x) => {
    // We also add a progress callback to the pipeline so that we can
    // track model loading.
    self.postMessage(x);
  });

  self.postMessage({
    status: "loading",
    data: "Compiling shaders and warming up model...",
  });

  // Run model with dummy input to compile shaders
  const inputs = tokenizer("a");
  await model.generate({ ...inputs, max_new_tokens: 1 });
  self.postMessage({ status: "ready" });
}

// Listen for messages from the main thread
self.addEventListener("message", async (e) => {
  const { type, data } = e.data;

  switch (type) {
    case "check":
      check();
      break;

    case "load":
      load();
      break;

    case "generate":
      stopping_criteria.reset();
      generate(data);
      break;

    case "interrupt":
      stopping_criteria.interrupt();
      break;

    case "reset":
      past_key_values_cache = null;
      stopping_criteria.reset();
      break;
  }
});
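
// ---------------------------------------------------------------------------
// Usage sketch (main thread). This is not part of the worker itself; it is a
// minimal example of how a page might drive the message protocol above
// (check -> load -> generate, with interrupt/reset available). The worker file
// path ("./worker.js") and the sample conversation are assumptions made for
// illustration only.
//
//   const worker = new Worker(new URL("./worker.js", import.meta.url), {
//     type: "module",
//   });
//
//   worker.addEventListener("message", ({ data }) => {
//     switch (data.status) {
//       case "error":
//         console.error(data.data);
//         break;
//       case "loading":
//         console.log(data.data);
//         break;
//       case "ready":
//         // Model is loaded and warmed up; start a generation.
//         worker.postMessage({
//           type: "generate",
//           data: {
//             messages: [{ role: "user", content: "What is WebGPU?" }],
//             reasonEnabled: true,
//           },
//         });
//         break;
//       case "update":
//         // Streaming chunk: data.output, data.tps, data.numTokens, data.state
//         console.log(data.output);
//         break;
//       case "complete":
//         console.log(data.output);
//         break;
//     }
//   });
//
//   // Feature-detect WebGPU, then load and warm up the model.
//   worker.postMessage({ type: "check" });
//   worker.postMessage({ type: "load" });
// ---------------------------------------------------------------------------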