import { useState, useCallback, useRef } from "react";
import {
  VoxtralForConditionalGeneration,
  VoxtralProcessor,
  TextStreamer,
  InterruptableStoppingCriteria,
} from "@huggingface/transformers";

type VoxtralStatus = "idle" | "loading" | "ready" | "transcribing" | "error";
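
/**
 * React hook that runs Voxtral Mini 3B (ONNX, via transformers.js) in the
 * browser and exposes model loading, streaming transcription, and an
 * interrupt control.
 */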
export function useVoxtral() {
  const [status, setStatus] = useState<VoxtralStatus>("idle");
  const [error, setError] = useState<string | null>(null);
  const [transcription, setTranscription] = useState<string>("");

  // Refs are seeded from window so hot reloads reuse an already-loaded model.
  const processorRef = useRef<any>((window as any).__VOXTRAL_PROCESSOR__ || null);
  const modelRef = useRef<any>((window as any).__VOXTRAL_MODEL__ || null);
  const stoppingCriteriaRef = useRef<any>(null);

  const loadModel = useCallback(async () => {
    setStatus("loading");
    setError(null);
    try {
      if (!processorRef.current || !modelRef.current) {
        const model_id = "onnx-community/Voxtral-Mini-3B-2507-ONNX";
        const processor = await VoxtralProcessor.from_pretrained(model_id);
        const model = await VoxtralForConditionalGeneration.from_pretrained(model_id, {
          dtype: {
            embed_tokens: "q4", // 252 MB
            audio_encoder: "q4", // 440 MB
            decoder_model_merged: "q4f16", // 2.0 GB
          },
          device: {
            embed_tokens: "wasm", // Just a look-up, so can be wasm
            audio_encoder: "webgpu",
            decoder_model_merged: "webgpu",
          },
        });
        processorRef.current = processor;
        modelRef.current = model;
        // Store globally to persist across hot reloads
        (window as any).__VOXTRAL_PROCESSOR__ = processor;
        (window as any).__VOXTRAL_MODEL__ = model;
      }
      setStatus("ready");
    } catch (err: any) {
      setStatus("error");
      setError("Failed to load model: " + (err?.message || err));
    }
  }, []);

  const transcribe = useCallback(async (audio: Float32Array, language: string = "en") => {
    const processor = processorRef.current;
    const model = modelRef.current;
    if (!processor || !model) {
      setError("Model not loaded");
      setStatus("error");
      return;
    }
    setStatus("transcribing");
    setTranscription("");
    setError(null);
    try {
      // Single-turn chat: an audio placeholder plus the `lang:<code>[TRANSCRIBE]` task prompt.
      const conversation = [
        {
          role: "user",
          content: [{ type: "audio" }, { type: "text", text: `lang:${language}[TRANSCRIBE]` }],
        },
      ];
      const text = processor.apply_chat_template(conversation, { tokenize: false });
      const inputs = await processor(text, audio);

      // Stream decoded tokens into state as they arrive.
      let output = "";
      const streamer = new TextStreamer(processor.tokenizer, {
        skip_special_tokens: true,
        skip_prompt: true,
        callback_function: (token: string) => {
          output += token;
          setTranscription((prev) => prev + token);
        },
      });

      stoppingCriteriaRef.current = new InterruptableStoppingCriteria();
      await model.generate({
        ...inputs,
        max_new_tokens: 8192,
        streamer,
        stopping_criteria: stoppingCriteriaRef.current,
      });
      setStatus("ready");
      return output;
    } catch (err: any) {
      setStatus("error");
      setError("Transcription failed: " + (err?.message || err));
    } finally {
      stoppingCriteriaRef.current = null;
    }
  }, []);

  const stopTranscription = useCallback(() => {
    if (stoppingCriteriaRef.current) {
      stoppingCriteriaRef.current.interrupt();
    }
  }, []);

  return {
    status,
    error,
    transcription,
    loadModel,
    transcribe,
    setTranscription,
    stopTranscription,
  };
}
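
/*
 * Minimal usage sketch (illustrative; `recordedAudio` and the surrounding
 * component are hypothetical, not part of this file):
 *
 *   const { status, transcription, loadModel, transcribe, stopTranscription } = useVoxtral();
 *
 *   // 1. Call loadModel() once, e.g. from a "Load model" button.
 *   // 2. With audio as a mono Float32Array (at the sampling rate the processor
 *   //    expects), call transcribe(recordedAudio, "en").
 *   // 3. Partial text streams into `transcription`; stopTranscription() interrupts generation.
 */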