Luigi committed on
Commit
454a10d
·
1 Parent(s): 0586d3c

add more models

Browse files
Files changed (3) hide show
  1. app/asr_worker.py +133 -23
  2. app/main.py +25 -13
  3. app/static/index.html +50 -1
app/asr_worker.py CHANGED
@@ -1,10 +1,10 @@
1
  import os
 
2
  import numpy as np
3
  import sherpa_onnx
4
  import scipy.signal
5
  from opencc import OpenCC
6
  from huggingface_hub import hf_hub_download
7
- from pathlib import Path
8
 
9
  # Ensure Hugging Face cache is in a user-writable directory
10
  CACHE_DIR = Path(__file__).parent / "hf_cache"
@@ -12,35 +12,145 @@ os.makedirs(CACHE_DIR, exist_ok=True)
12
 
13
  converter = OpenCC('s2t')
14
 
15
- # ASR model repository and file paths
16
- REPO_ID = "pfluo/k2fsa-zipformer-chinese-english-mixed"
17
- FILES = {
18
- "tokens": "data/lang_char_bpe/tokens.txt",
19
- "encoder": "exp/encoder-epoch-99-avg-1.int8.onnx",
20
- "decoder": "exp/decoder-epoch-99-avg-1.onnx",
21
- "joiner": "exp/joiner-epoch-99-avg-1.int8.onnx",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  }
23
 
24
- # Download and cache each file via HuggingFace Hub
25
- LOCAL_PATHS = {}
26
- for key, path in FILES.items():
27
- LOCAL_PATHS[key] = hf_hub_download(
28
- repo_id=REPO_ID,
29
- filename=path,
30
- cache_dir=str(CACHE_DIR),
31
- )
32
-
33
  # Audio resampling utility
34
  def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
35
  return scipy.signal.resample_poly(audio, target_sr, orig_sr)
36
 
37
- # Build the online recognizer with int8 weights
38
- def create_recognizer():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  return sherpa_onnx.OnlineRecognizer.from_transducer(
40
- tokens=LOCAL_PATHS['tokens'],
41
- encoder=LOCAL_PATHS['encoder'],
42
- decoder=LOCAL_PATHS['decoder'],
43
- joiner=LOCAL_PATHS['joiner'],
44
  provider="cpu",
45
  num_threads=1,
46
  sample_rate=16000,
 
1
  import os
2
+ from pathlib import Path
3
  import numpy as np
4
  import sherpa_onnx
5
  import scipy.signal
6
  from opencc import OpenCC
7
  from huggingface_hub import hf_hub_download
 
8
 
9
  # Ensure Hugging Face cache is in a user-writable directory
10
  CACHE_DIR = Path(__file__).parent / "hf_cache"
 
12
 
13
  converter = OpenCC('s2t')
14
 
15
+ # Streaming Zipformer model registry: paths relative to repo root
16
+ STREAMING_ZIPFORMER_MODELS = {
17
+ "pfluo/k2fsa-zipformer-chinese-english-mixed": {
18
+ "tokens": "data/lang_char_bpe/tokens.txt",
19
+ "encoder_fp32": "exp/encoder-epoch-99-avg-1.onnx",
20
+ "encoder_int8": "exp/encoder-epoch-99-avg-1.int8.onnx",
21
+ "decoder_fp32": "exp/decoder-epoch-99-avg-1.onnx",
22
+ "decoder_int8": None,
23
+ "joiner_fp32": "exp/joiner-epoch-99-avg-1.onnx",
24
+ "joiner_int8": "exp/joiner-epoch-99-avg-1.int8.onnx",
25
+ },
26
+ "k2-fsa/sherpa-onnx-streaming-zipformer-korean-2024-06-16": {
27
+ "tokens": "tokens.txt",
28
+ "encoder_fp32": "encoder-epoch-99-avg-1.onnx",
29
+ "encoder_int8": "encoder-epoch-99-avg-1.int8.onnx",
30
+ "decoder_fp32": "decoder-epoch-99-avg-1.onnx",
31
+ "decoder_int8": "decoder-epoch-99-avg-1.int8.onnx",
32
+ "joiner_fp32": "joiner-epoch-99-avg-1.onnx",
33
+ "joiner_int8": "joiner-epoch-99-avg-1.int8.onnx",
34
+ },
35
+ "k2-fsa/sherpa-onnx-streaming-zipformer-multi-zh-hans-2023-12-12": {
36
+ "tokens": "tokens.txt",
37
+ "encoder_fp32": "encoder-epoch-20-avg-1-chunk-16-left-128.onnx",
38
+ "encoder_int8": "encoder-epoch-20-avg-1-chunk-16-left-128.int8.onnx",
39
+ "decoder_fp32": "decoder-epoch-20-avg-1-chunk-16-left-128.onnx",
40
+ "decoder_int8": "decoder-epoch-20-avg-1-chunk-16-left-128.int8.onnx",
41
+ "joiner_fp32": "joiner-epoch-20-avg-1-chunk-16-left-128.onnx",
42
+ "joiner_int8": "joiner-epoch-20-avg-1-chunk-16-left-128.int8.onnx",
43
+ },
44
+ "pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615": {
45
+ "tokens": "data/lang_char/tokens.txt",
46
+ "encoder_fp32": "exp/encoder-epoch-12-avg-4-chunk-16-left-128.onnx",
47
+ "encoder_int8": "exp/encoder-epoch-12-avg-4-chunk-16-left-128.int8.onnx",
48
+ "decoder_fp32": "exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
49
+ "decoder_int8": "exp/decoder-epoch-12-avg-4-chunk-16-left-128.int8.onnx",
50
+ "joiner_fp32": "exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
51
+ "joiner_int8": "exp/joiner-epoch-12-avg-4-chunk-16-left-128.int8.onnx",
52
+ },
53
+ "csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26": {
54
+ "tokens": "tokens.txt",
55
+ "encoder_fp32": "encoder-epoch-99-avg-1-chunk-16-left-128.onnx",
56
+ "encoder_int8": "encoder-epoch-99-avg-1-chunk-16-left-128.int8.onnx",
57
+ "decoder_fp32": "decoder-epoch-99-avg-1-chunk-16-left-128.onnx",
58
+ "decoder_int8": None,
59
+ "joiner_fp32": "joiner-epoch-99-avg-1-chunk-16-left-128.onnx",
60
+ "joiner_int8": "joiner-epoch-99-avg-1-chunk-16-left-128.int8.onnx",
61
+ },
62
+ "csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-21": {
63
+ "tokens": "tokens.txt",
64
+ "encoder_fp32": "encoder-epoch-99-avg-1.onnx",
65
+ "encoder_int8": "encoder-epoch-99-avg-1.int8.onnx",
66
+ "decoder_fp32": "decoder-epoch-99-avg-1.onnx",
67
+ "decoder_int8": "decoder-epoch-99-avg-1.int8.onnx",
68
+ "joiner_fp32": "joiner-epoch-99-avg-1.onnx",
69
+ "joiner_int8": "joiner-epoch-99-avg-1.int8.onnx",
70
+ },
71
+ "csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21": {
72
+ "tokens": "tokens.txt",
73
+ "encoder_fp32": "encoder-epoch-99-avg-1.onnx",
74
+ "encoder_int8": "encoder-epoch-99-avg-1.int8.onnx",
75
+ "decoder_fp32": "decoder-epoch-99-avg-1.onnx",
76
+ "decoder_int8": "decoder-epoch-99-avg-1.int8.onnx",
77
+ "joiner_fp32": "joiner-epoch-99-avg-1.onnx",
78
+ "joiner_int8": "joiner-epoch-99-avg-1.int8.onnx",
79
+ },
80
+ "csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20": {
81
+ "tokens": "tokens.txt",
82
+ "encoder_fp32": "encoder-epoch-99-avg-1.onnx",
83
+ "encoder_int8": "encoder-epoch-99-avg-1.int8.onnx",
84
+ "decoder_fp32": "decoder-epoch-99-avg-1.onnx",
85
+ "decoder_int8": "decoder-epoch-99-avg-1.int8.onnx",
86
+ "joiner_fp32": "joiner-epoch-99-avg-1.onnx",
87
+ "joiner_int8": "joiner-epoch-99-avg-1.int8.onnx",
88
+ },
89
+ "shaojieli/sherpa-onnx-streaming-zipformer-fr-2023-04-14": {
90
+ "tokens": "tokens.txt",
91
+ "encoder_fp32": "encoder-epoch-29-avg-9-with-averaged-model.onnx",
92
+ "encoder_int8": "encoder-epoch-29-avg-9-with-averaged-model.int8.onnx",
93
+ "decoder_fp32": "decoder-epoch-29-avg-9-with-averaged-model.onnx",
94
+ "decoder_int8": "decoder-epoch-29-avg-9-with-averaged-model.int8.onnx",
95
+ "joiner_fp32": "joiner-epoch-29-avg-9-with-averaged-model.onnx",
96
+ "joiner_int8": "joiner-epoch-29-avg-9-with-averaged-model.int8.onnx",
97
+ },
98
+ "sherpa-onnx-streaming-zipformer-small-bilingual-zh-en-2023-02-16": {
99
+ "tokens": "tokens.txt",
100
+ "encoder_fp32": "encoder-epoch-99-avg-1.onnx",
101
+ "encoder_int8": "encoder-epoch-99-avg-1.int8.onnx",
102
+ "decoder_fp32": "decoder-epoch-99-avg-1.onnx",
103
+ "decoder_int8": "decoder-epoch-99-avg-1.int8.onnx",
104
+ "joiner_fp32": "joiner-epoch-99-avg-1.onnx",
105
+ "joiner_int8": "joiner-epoch-99-avg-1.int8.onnx",
106
+ },
107
+ "csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23": {
108
+ "tokens": "tokens.txt",
109
+ "encoder_fp32": "encoder-epoch-99-avg-1.onnx",
110
+ "encoder_int8": "encoder-epoch-99-avg-1.int8.onnx",
111
+ "decoder_fp32": "decoder-epoch-99-avg-1.onnx",
112
+ "decoder_int8": "decoder-epoch-99-avg-1.int8.onnx",
113
+ "joiner_fp32": "joiner-epoch-99-avg-1.onnx",
114
+ "joiner_int8": "joiner-epoch-99-avg-1.int8.onnx",
115
+ },
116
+ "csukuangfj/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17": {
117
+ "tokens": "tokens.txt",
118
+ "encoder_fp32": "encoder-epoch-99-avg-1.onnx",
119
+ "encoder_int8": "encoder-epoch-99-avg-1.int8.onnx",
120
+ "decoder_fp32": "decoder-epoch-99-avg-1.onnx",
121
+ "decoder_int8": "decoder-epoch-99-avg-1.int8.onnx",
122
+ "joiner_fp32": "joiner-epoch-99-avg-1.onnx",
123
+ "joiner_int8": "joiner-epoch-99-avg-1.int8.onnx",
124
+ },
125
  }
126
 
 
 
 
 
 
 
 
 
 
127
  # Audio resampling utility
128
  def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
129
  return scipy.signal.resample_poly(audio, target_sr, orig_sr)
130
 
131
+ # Create an online recognizer for a given model and precision
132
+ # model_id: full HF repo ID
133
+ # precision: "int8" or "fp32"
134
+ def create_recognizer(model_id: str, precision: str):
135
+ if model_id not in STREAMING_ZIPFORMER_MODELS:
136
+ raise ValueError(f"Model '{model_id}' is not registered.")
137
+ entry = STREAMING_ZIPFORMER_MODELS[model_id]
138
+
139
+ tokens_file = entry['tokens']
140
+ encoder_file = entry['encoder_int8'] if precision == 'int8' else entry['encoder_fp32']
141
+ decoder_file = entry['decoder_fp32']
142
+ joiner_file = entry['joiner_int8'] if precision == 'int8' else entry['joiner_fp32']
143
+
144
+ tokens_path = hf_hub_download(repo_id=model_id, filename=tokens_file, cache_dir=str(CACHE_DIR))
145
+ encoder_path = hf_hub_download(repo_id=model_id, filename=encoder_file, cache_dir=str(CACHE_DIR))
146
+ decoder_path = hf_hub_download(repo_id=model_id, filename=decoder_file, cache_dir=str(CACHE_DIR))
147
+ joiner_path = hf_hub_download(repo_id=model_id, filename=joiner_file, cache_dir=str(CACHE_DIR))
148
+
149
  return sherpa_onnx.OnlineRecognizer.from_transducer(
150
+ tokens=tokens_path,
151
+ encoder=encoder_path,
152
+ decoder=decoder_path,
153
+ joiner=joiner_path,
154
  provider="cpu",
155
  num_threads=1,
156
  sample_rate=16000,
app/main.py CHANGED
@@ -8,8 +8,6 @@ app = FastAPI()
8
 
9
  app.mount("/static", StaticFiles(directory="app/static"), name="static")
10
 
11
- recognizer = create_recognizer()
12
-
13
  @app.get("/")
14
  async def root():
15
  with open("app/static/index.html") as f:
@@ -22,24 +20,24 @@ async def websocket_endpoint(websocket: WebSocket):
22
  await websocket.accept()
23
  print("[DEBUG main] ▶ WebSocket.accept() returned → client is connected!")
24
 
25
- # Immediately create a new stream per client
26
- stream = recognizer.create_stream()
27
  orig_sr = 48000 # default fallback
28
- print("[INFO main] WebSocket connection accepted; created a streaming context.")
29
 
30
  try:
31
  while True:
32
  data = await websocket.receive()
33
  kind = data.get("type")
34
 
35
- # Debug: log any event we don't handle explicitly
36
  if kind not in ("websocket.receive", "websocket.receive_bytes"):
37
  print(f"[DEBUG main] Received control/frame: {data}")
38
- # If client cleanly disconnected, finalize and break
39
  if kind == "websocket.disconnect":
40
- print(f"[INFO main] Client disconnected (code={data.get('code')}). Flushing final transcript...")
41
- final = finalize_stream(stream, recognizer)
42
- await websocket.send_json({"final": final})
 
 
43
  break
44
  continue
45
 
@@ -54,7 +52,20 @@ async def websocket_endpoint(websocket: WebSocket):
54
  if config_msg.get("type") == "config":
55
  orig_sr = int(config_msg["sampleRate"])
56
  print(f"[INFO main] Set original sample rate to {orig_sr}")
57
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # If it’s a text payload but with bytes (some FastAPI versions put audio under 'text'!)
60
  if kind == "websocket.receive" and "bytes" in data:
@@ -82,7 +93,8 @@ async def websocket_endpoint(websocket: WebSocket):
82
  })
83
  except Exception as e:
84
  print(f"[ERROR main] Unexpected exception: {e}")
85
- final = finalize_stream(stream, recognizer)
86
- await websocket.send_json({"final": final})
 
87
  await websocket.close()
88
  print("[INFO main] WebSocket closed, cleanup complete.")
 
8
 
9
  app.mount("/static", StaticFiles(directory="app/static"), name="static")
10
 
 
 
11
  @app.get("/")
12
  async def root():
13
  with open("app/static/index.html") as f:
 
20
  await websocket.accept()
21
  print("[DEBUG main] ▶ WebSocket.accept() returned → client is connected!")
22
 
23
+ recognizer = None
24
+ stream = None
25
  orig_sr = 48000 # default fallback
 
26
 
27
  try:
28
  while True:
29
  data = await websocket.receive()
30
  kind = data.get("type")
31
 
32
+ # Handle control frames
33
  if kind not in ("websocket.receive", "websocket.receive_bytes"):
34
  print(f"[DEBUG main] Received control/frame: {data}")
 
35
  if kind == "websocket.disconnect":
36
+ # On client disconnect, flush final transcript if possible
37
+ if stream and recognizer:
38
+ print(f"[INFO main] Client disconnected (code={data.get('code')}). Flushing final transcript...")
39
+ final = finalize_stream(stream, recognizer)
40
+ await websocket.send_json({"final": final})
41
  break
42
  continue
43
 
 
52
  if config_msg.get("type") == "config":
53
  orig_sr = int(config_msg["sampleRate"])
54
  print(f"[INFO main] Set original sample rate to {orig_sr}")
55
+
56
+ # New: dynamic model & precision
57
+ model_id = config_msg.get("model")
58
+ precision = config_msg.get("precision")
59
+ print(f"[INFO main] Selected model: {model_id}, precision: {precision}")
60
+
61
+ recognizer = create_recognizer(model_id, precision)
62
+ stream = recognizer.create_stream()
63
+ print("[INFO main] WebSocket connection accepted; created a streaming context.")
64
+ continue
65
+
66
+ # Don't process audio until after config
67
+ if recognizer is None or stream is None:
68
+ continue
69
 
70
  # If it’s a text payload but with bytes (some FastAPI versions put audio under 'text'!)
71
  if kind == "websocket.receive" and "bytes" in data:
 
93
  })
94
  except Exception as e:
95
  print(f"[ERROR main] Unexpected exception: {e}")
96
+ if stream and recognizer:
97
+ final = finalize_stream(stream, recognizer)
98
+ await websocket.send_json({"final": final})
99
  await websocket.close()
100
  print("[INFO main] WebSocket closed, cleanup complete.")
app/static/index.html CHANGED
@@ -70,10 +70,52 @@
70
  font-size: 1.4rem;
71
  color: #e84118;
72
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  </style>
74
  </head>
75
  <body>
76
  <h1>🎤 Speak into your microphone</h1>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  <progress id="vol" max="1" value="0"></progress>
78
 
79
  <div class="output">
@@ -89,6 +131,8 @@
89
  const vol = document.getElementById("vol");
90
  const partial = document.getElementById("partial");
91
  const finalText = document.getElementById("final");
 
 
92
 
93
  navigator.mediaDevices.getUserMedia({ audio: true }).then(stream => {
94
  const context = new AudioContext();
@@ -96,7 +140,12 @@
96
 
97
  ws.onopen = () => {
98
  console.log("[DEBUG client] WebSocket.onopen fired!");
99
- ws.send(JSON.stringify({ type: "config", sampleRate: orig_sample_rate }));
 
 
 
 
 
100
  };
101
  ws.onerror = err => {
102
  console.error("[DEBUG client] WebSocket.onerror:", err);
 
70
  font-size: 1.4rem;
71
  color: #e84118;
72
  }
73
+
74
+ .controls {
75
+ display: flex;
76
+ gap: 1rem;
77
+ margin-bottom: 1rem;
78
+ align-items: center;
79
+ }
80
+ .controls label {
81
+ font-weight: bold;
82
+ color: #2f3640;
83
+ }
84
+ .controls select {
85
+ padding: 0.3rem;
86
+ border-radius: 5px;
87
+ border: 1px solid #dcdde1;
88
+ background: white;
89
+ }
90
  </style>
91
  </head>
92
  <body>
93
  <h1>🎤 Speak into your microphone</h1>
94
+
95
+ <div class="controls">
96
+ <label for="modelSelect">Model:</label>
97
+ <select id="modelSelect">
98
+ <option value="pfluo/k2fsa-zipformer-chinese-english-mixed">k2fsa-chinese-english-mixed</option>
99
+ <option value="k2-fsa/sherpa-onnx-streaming-zipformer-korean-2024-06-16">sherpa-onnx-zipformer-korean</option>
100
+ <option value="k2-fsa/sherpa-onnx-streaming-zipformer-multi-zh-hans-2023-12-12">zipformer-multi-zh-hans</option>
101
+ <option value="pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615">icefall-zipformer-wenetspeech</option>
102
+ <option value="csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26">zipformer-en-06-26</option>
103
+ <option value="csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-21">zipformer-en-06-21</option>
104
+ <option value="csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21">zipformer-en-02-21</option>
105
+ <option value="csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20">zipformer-zh-en</option>
106
+ <option value="shaojieli/sherpa-onnx-streaming-zipformer-fr-2023-04-14">zipformer-fr</option>
107
+ <option value="sherpa-onnx-streaming-zipformer-small-bilingual-zh-en-2023-02-16">zipformer-small-zh-en</option>
108
+ <option value="csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23">zipformer-zh-14M</option>
109
+ <option value="csukuangfj/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17">zipformer-en-20M</option>
110
+ </select>
111
+
112
+ <label for="precisionSelect">Precision:</label>
113
+ <select id="precisionSelect">
114
+ <option value="fp32">FP32</option>
115
+ <option value="int8">INT8</option>
116
+ </select>
117
+ </div>
118
+
119
  <progress id="vol" max="1" value="0"></progress>
120
 
121
  <div class="output">
 
131
  const vol = document.getElementById("vol");
132
  const partial = document.getElementById("partial");
133
  const finalText = document.getElementById("final");
134
+ const modelSelect = document.getElementById("modelSelect");
135
+ const precisionSelect = document.getElementById("precisionSelect");
136
 
137
  navigator.mediaDevices.getUserMedia({ audio: true }).then(stream => {
138
  const context = new AudioContext();
 
140
 
141
  ws.onopen = () => {
142
  console.log("[DEBUG client] WebSocket.onopen fired!");
143
+ ws.send(JSON.stringify({
144
+ type: "config",
145
+ sampleRate: orig_sample_rate,
146
+ model: modelSelect.value,
147
+ precision: precisionSelect.value
148
+ }));
149
  };
150
  ws.onerror = err => {
151
  console.error("[DEBUG client] WebSocket.onerror:", err);