import { useState, useCallback, useRef } from "react";
import {
  VoxtralForConditionalGeneration,
  VoxtralProcessor,
  TextStreamer,
  InterruptableStoppingCriteria,
} from "@huggingface/transformers";

type VoxtralStatus = "idle" | "loading" | "ready" | "transcribing" | "error";

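/**
 * Loads Voxtral-Mini-3B (ONNX) in the browser via transformers.js and exposes
 * streaming speech-to-text. The audio encoder and decoder run on WebGPU, so a
 * WebGPU-capable browser is required.
 *
 * Usage sketch (component wiring and audio capture are assumed, not part of
 * this file): call loadModel() once, then pass mono Float32Array samples at
 * 16 kHz (the rate Voxtral's feature extractor expects) to
 * transcribe(samples, "en") and render `transcription` as it streams.
 */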
export function useVoxtral() {
  const [status, setStatus] = useState<VoxtralStatus>("idle");
  const [error, setError] = useState<string | null>(null);
  const [transcription, setTranscription] = useState<string>("");

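  // Instances are recovered from window globals when a previous mount already
  // created them, so the multi-GB model survives dev-server hot reloads.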
  const processorRef = useRef<any>((window as any).__VOXTRAL_PROCESSOR__ || null);
  const modelRef = useRef<any>((window as any).__VOXTRAL_MODEL__ || null);
  const stoppingCriteriaRef = useRef<any>(null);

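  // Downloads (or reuses) the processor and model. Safe to call repeatedly:
  // when both refs are already populated it only flips status back to "ready".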
  const loadModel = useCallback(async () => {
    setStatus("loading");
    setError(null);
    try {
      if (!processorRef.current || !modelRef.current) {
        const model_id = "onnx-community/Voxtral-Mini-3B-2507-ONNX";
        const processor = await VoxtralProcessor.from_pretrained(model_id);
        const model = await VoxtralForConditionalGeneration.from_pretrained(model_id, {
          dtype: {
            embed_tokens: "q4", // 252 MB
            audio_encoder: "q4", // 440 MB
            decoder_model_merged: "q4f16", // 2.0 GB
          },
          device: {
            embed_tokens: "wasm", // Embedding lookup only, so wasm is sufficient
            audio_encoder: "webgpu",
            decoder_model_merged: "webgpu",
          },
        });
        processorRef.current = processor;
        modelRef.current = model;

        // Store globally to persist across hot reloads
        (window as any).__VOXTRAL_PROCESSOR__ = processor;
        (window as any).__VOXTRAL_MODEL__ = model;
      }
      setStatus("ready");
    } catch (err: any) {
      setStatus("error");
      setError("Failed to load model: " + (err?.message || err));
    }
  }, []);

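  // Runs one transcription pass: builds Voxtral's transcription prompt for the
  // requested language, feeds it to the processor together with the raw audio
  // samples, and streams decoded tokens into state as they arrive.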
  const transcribe = useCallback(async (audio: Float32Array, language: string = "en") => {
    const processor = processorRef.current;
    const model = modelRef.current;
    if (!processor || !model) {
      setError("Model not loaded");
      setStatus("error");
      return;
    }
    setStatus("transcribing");
    setTranscription("");
    setError(null);
    try {
      const conversation = [
        {
          role: "user",
          content: [{ type: "audio" }, { type: "text", text: `lang:${language}[TRANSCRIBE]` }],
        },
      ];
      const text = processor.apply_chat_template(conversation, { tokenize: false });
      const inputs = await processor(text, audio);

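      // Accumulate tokens locally for the return value and mirror them into
      // React state so the UI can render the partial transcript live.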
      let output = "";
      const streamer = new TextStreamer(processor.tokenizer, {
        skip_special_tokens: true,
        skip_prompt: true,
        callback_function: (token: string) => {
          output += token;
          setTranscription((prev) => prev + token);
        },
      });

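      // Fresh criteria per run, so an interrupt() left over from a previous
      // transcription cannot stop this one.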
      stoppingCriteriaRef.current = new InterruptableStoppingCriteria();

      await model.generate({
        ...inputs,
        max_new_tokens: 8192,
        streamer,
        stopping_criteria: stoppingCriteriaRef.current,
      });

      setStatus("ready");
      return output;
    } catch (err: any) {
      setStatus("error");
      setError("Transcription failed: " + (err?.message || err));
    } finally {
      stoppingCriteriaRef.current = null;
    }
  }, []);

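  // Asks the in-flight generate() call to stop; generation halts after the
  // current token and transcribe() resolves with the partial output.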
  const stopTranscription = useCallback(() => {
    if (stoppingCriteriaRef.current) {
      stoppingCriteriaRef.current.interrupt();
    }
  }, []);

  return {
    status,
    error,
    transcription,
    loadModel,
    transcribe,
    setTranscription,
    stopTranscription,
  };
}