import {
  AutoTokenizer,
  AutoModelForCausalLM,
  TextStreamer,
  InterruptableStoppingCriteria,
} from "@huggingface/transformers";

/**
 * Helper function to perform feature detection for WebGPU
 */
async function check() {
  try {
    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) {
      throw new Error("WebGPU is not supported (no adapter found)");
    }
    if (!adapter.features.has("shader-f16")) {
      throw new Error("shader-f16 is not supported in this browser");
    }
  } catch (e) {
    self.postMessage({
      status: "error",
      data: e.toString(),
    });
  }
}

/**
 * This class uses the Singleton pattern to enable lazy-loading of the pipeline
 */
class TextGenerationPipeline {
  static model_id = "HuggingFaceTB/SmolLM3-3B-ONNX";

  static async getInstance(progress_callback = null) {
    this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
      progress_callback,
    });

    this.model ??= AutoModelForCausalLM.from_pretrained(this.model_id, {
      dtype: "q4f16",
      device: "webgpu",
      progress_callback,
    });

    return Promise.all([this.tokenizer, this.model]);
  }
}
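// Module-level state: the stopping criteria let the main thread interrupt an
// in-flight generation, and the KV cache is reused across turns of the same
// conversation so earlier tokens are not re-processed.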
const stopping_criteria = new InterruptableStoppingCriteria();

let past_key_values_cache = null;

async function generate({ messages, reasonEnabled }) {
  const [tokenizer, model] = await TextGenerationPipeline.getInstance();

  const inputs = tokenizer.apply_chat_template(messages, {
    enable_thinking: reasonEnabled,
    add_generation_prompt: true,
    return_dict: true,
  });
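  // Token IDs that mark the start and end of the model's "thinking" segment,
  // used below to switch the streamed state between reasoning and answering.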
  const [START_THINKING_TOKEN_ID, END_THINKING_TOKEN_ID] = tokenizer.encode(
    "<think></think>",
    { add_special_tokens: false },
  );

  let state = "answering"; // 'thinking' or 'answering'
  let startTime;
  let numTokens = 0;
  let tps;
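  // Called once per generated token: records the generation speed (tokens/sec)
  // and flips the state whenever a thinking delimiter token is emitted.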
  const token_callback_function = (tokens) => {
    startTime ??= performance.now();
    if (numTokens++ > 0) {
      tps = (numTokens / (performance.now() - startTime)) * 1000;
    }
    switch (Number(tokens[0])) {
      case START_THINKING_TOKEN_ID:
        state = "thinking";
        break;
      case END_THINKING_TOKEN_ID:
        state = "answering";
        break;
    }
  };

  const callback_function = (output) => {
    self.postMessage({
      status: "update",
      output,
      tps,
      numTokens,
      state,
    });
  };

  const streamer = new TextStreamer(tokenizer, {
    skip_prompt: true,
    skip_special_tokens: true,
    callback_function,
    token_callback_function,
  });

  // Tell the main thread we are starting
  self.postMessage({ status: "start" });
  const { past_key_values, sequences } = await model.generate({
    ...inputs,
    past_key_values: past_key_values_cache,

    // Sampling
    do_sample: !reasonEnabled,
    repetition_penalty: reasonEnabled ? 1.1 : undefined,
    top_k: 3,
    max_new_tokens: reasonEnabled ? 4096 : 1024,

    streamer,
    stopping_criteria,
    return_dict_in_generate: true,
  });
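  // Keep the KV cache so the next turn of the conversation does not have to
  // re-process the shared prefix.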
  past_key_values_cache = past_key_values;

  const decoded = tokenizer.batch_decode(sequences, {
    skip_special_tokens: true,
  });

  // Send the output back to the main thread
  self.postMessage({
    status: "complete",
    output: decoded,
  });
}

async function load() {
  self.postMessage({
    status: "loading",
    data: "Loading model...",
  });

  // Load the pipeline and save it for future use.
  const [tokenizer, model] = await TextGenerationPipeline.getInstance((x) => {
    // We also add a progress callback to the pipeline so that we can
    // track model loading.
    self.postMessage(x);
  });

  self.postMessage({
    status: "loading",
    data: "Compiling shaders and warming up model...",
  });

  // Run model with dummy input to compile shaders
  const inputs = tokenizer("a");
  await model.generate({ ...inputs, max_new_tokens: 1 });

  self.postMessage({ status: "ready" });
}

// Listen for messages from the main thread
self.addEventListener("message", async (e) => {
  const { type, data } = e.data;

  switch (type) {
    case "check":
      check();
      break;

    case "load":
      load();
      break;

    case "generate":
      stopping_criteria.reset();
      generate(data);
      break;

    case "interrupt":
      stopping_criteria.interrupt();
      break;

    case "reset":
      past_key_values_cache = null;
      stopping_criteria.reset();
      break;
  }
});
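// Illustrative sketch (not part of the worker itself): how the main thread might
// drive this worker. The filename "worker.js" is an assumption; the message types
// and statuses match the handlers above.
//
//   const worker = new Worker(new URL("./worker.js", import.meta.url), { type: "module" });
//   worker.addEventListener("message", ({ data }) => {
//     // data.status is "error", "loading", "ready", "start", "update", or "complete",
//     // plus raw download-progress events forwarded while the model loads.
//     console.log(data);
//   });
//   worker.postMessage({ type: "check" });
//   worker.postMessage({ type: "load" });
//   worker.postMessage({
//     type: "generate",
//     data: { messages: [{ role: "user", content: "Hi!" }], reasonEnabled: true },
//   });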