Luigi committed on
Commit
548b7ed
·
1 Parent(s): 30a3b5d

add endpoint detection

Browse files
Files changed (3) hide show
  1. app/asr_worker.py +13 -0
  2. app/main.py +29 -3
  3. app/static/index.html +124 -23
app/asr_worker.py CHANGED
@@ -185,6 +185,9 @@ def create_recognizer(
185
  precision: str,
186
  hotwords: List[str] = None,
187
  hotwords_score: float = 0.0,
 
 
 
188
  ):
189
  if model_id not in STREAMING_ZIPFORMER_MODELS:
190
  raise ValueError(f"Model '{model_id}' is not registered.")
@@ -262,6 +265,11 @@ def create_recognizer(
262
  hotwords_score=hotwords_score,
263
  modeling_unit=modeling_unit,
264
  bpe_vocab=bpe_vocab_path,
 
 
 
 
 
265
  )
266
 
267
  # ——— Fallback to original greedy-search (no hotword biasing) ———
@@ -275,6 +283,11 @@ def create_recognizer(
275
  sample_rate=16000,
276
  feature_dim=80,
277
  decoding_method="greedy_search",
 
 
 
 
 
278
  )
279
 
280
  def stream_audio(raw_pcm_bytes, stream, recognizer, orig_sr):
 
185
  precision: str,
186
  hotwords: List[str] = None,
187
  hotwords_score: float = 0.0,
188
+ ep_rule1: float = 2.4,
189
+ ep_rule2: float = 1.2,
190
+ ep_rule3: int = 300,
191
  ):
192
  if model_id not in STREAMING_ZIPFORMER_MODELS:
193
  raise ValueError(f"Model '{model_id}' is not registered.")
 
265
  hotwords_score=hotwords_score,
266
  modeling_unit=modeling_unit,
267
  bpe_vocab=bpe_vocab_path,
268
+ # endpoint detection parameters
269
+ enable_endpoint_detection=True,
270
+ rule1_min_trailing_silence=ep_rule1,
271
+ rule2_min_trailing_silence=ep_rule2,
272
+ rule3_min_utterance_length=ep_rule3,
273
  )
274
 
275
  # ——— Fallback to original greedy-search (no hotword biasing) ———
 
283
  sample_rate=16000,
284
  feature_dim=80,
285
  decoding_method="greedy_search",
286
+ # endpoint detection parameters
287
+ enable_endpoint_detection=True,
288
+ rule1_min_trailing_silence=ep_rule1,
289
+ rule2_min_trailing_silence=ep_rule2,
290
+ rule3_min_utterance_length=ep_rule3,
291
  )
292
 
293
  def stream_audio(raw_pcm_bytes, stream, recognizer, orig_sr):
app/main.py CHANGED
@@ -56,12 +56,21 @@ async def websocket_endpoint(websocket: WebSocket):
56
  hotwords_score = float(config_msg.get("hotwordsScore", 0.0))
57
  print(f"[INFO main] Hotwords: {hotwords}, score: {hotwords_score}")
58
 
59
- # 4) create recognizer with biasing
 
 
 
 
 
 
60
  recognizer = create_recognizer(
61
  model_id,
62
  precision,
63
  hotwords=hotwords,
64
- hotwords_score=hotwords_score
 
 
 
65
  )
66
  stream = recognizer.create_stream()
67
  print("[INFO main] WebSocket connection accepted; created a streaming context.")
@@ -78,8 +87,20 @@ async def websocket_endpoint(websocket: WebSocket):
78
  result, rms = stream_audio(raw_audio, stream, recognizer, orig_sr)
79
  vol_to_send = min(rms, 1.0)
80
  # print(f"[INFO main] Sending → partial='{result[:30]}…', volume={vol_to_send:.4f}")
 
81
  await websocket.send_json({"partial": result, "volume": vol_to_send})
82
- continue
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  elif kind == "websocket.receive_bytes":
85
  raw_audio = data["bytes"]
@@ -95,6 +116,11 @@ async def websocket_endpoint(websocket: WebSocket):
95
  "partial": result,
96
  "volume": min(rms, 1.0)
97
  })
 
 
 
 
 
98
  except Exception as e:
99
  print(f"[ERROR main] Unexpected exception: {e}")
100
  try:
 
56
  hotwords_score = float(config_msg.get("hotwordsScore", 0.0))
57
  print(f"[INFO main] Hotwords: {hotwords}, score: {hotwords_score}")
58
 
59
+ # 4) Parse endpoint detection rules
60
+ ep1 = float(config_msg.get("epRule1", 2.4))
61
+ ep2 = float(config_msg.get("epRule2", 1.2))
62
+ ep3 = int( config_msg.get("epRule3", 300))
63
+ print(f"[INFO main] Endpoint rules: rule1={ep1}s, rule2={ep2}s, rule3={ep3}ms")
64
+
65
+ # 5) create recognizer with endpoint settings & biasing
66
  recognizer = create_recognizer(
67
  model_id,
68
  precision,
69
  hotwords=hotwords,
70
+ hotwords_score=hotwords_score,
71
+ ep_rule1=ep1,
72
+ ep_rule2=ep2,
73
+ ep_rule3=ep3
74
  )
75
  stream = recognizer.create_stream()
76
  print("[INFO main] WebSocket connection accepted; created a streaming context.")
 
87
  result, rms = stream_audio(raw_audio, stream, recognizer, orig_sr)
88
  vol_to_send = min(rms, 1.0)
89
  # print(f"[INFO main] Sending → partial='{result[:30]}…', volume={vol_to_send:.4f}")
90
+ # 1) send the interim
91
  await websocket.send_json({"partial": result, "volume": vol_to_send})
92
+
93
+ # 2) DEBUG: log when endpoint is seen
94
+ is_ep = recognizer.is_endpoint(stream)
95
+ # print(f"[DEBUG main] is_endpoint={is_ep}")
96
+
97
+ # 3) if endpoint, emit final and reset
98
+ if is_ep:
99
+ if result.strip():
100
+ print(f"[DEBUG main] Emitting final: {result!r}")
101
+ await websocket.send_json({"final": result})
102
+ recognizer.reset(stream)
103
+ continue
104
 
105
  elif kind == "websocket.receive_bytes":
106
  raw_audio = data["bytes"]
 
116
  "partial": result,
117
  "volume": min(rms, 1.0)
118
  })
119
+ # -- INSERT: emit final on endpoint detection --
120
+ if recognizer.is_endpoint(stream):
121
+ if result.strip():
122
+ await websocket.send_json({"final": result})
123
+ recognizer.reset(stream)
124
  except Exception as e:
125
  print(f"[ERROR main] Unexpected exception: {e}")
126
  try:
app/static/index.html CHANGED
@@ -4,6 +4,24 @@
4
  <meta charset="UTF-8" />
5
  <title>🎤 Real-Time ASR Demo</title>
6
  <style>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  body {
8
  font-family: "Segoe UI", sans-serif;
9
  background-color: #f5f6fa;
@@ -157,6 +175,10 @@
157
  </select>
158
  </div>
159
 
 
 
 
 
160
  <div class="controls">
161
  <!-- Hotwords List Input -->
162
  <label for="hotwordsList">Hotwords:</label>
@@ -173,8 +195,19 @@
173
  <span id="hotwordStatus">Hotword Bias: Off</span>
174
  </div>
175
 
176
- <div class="model-info" id="modelInfo">
177
- Languages: <span id="modelLangs"></span> | Size: <span id="modelSize"></span> MB
 
 
 
 
 
 
 
 
 
 
 
178
  </div>
179
 
180
  <div class="mic-info">
@@ -242,21 +275,6 @@
242
  modelSize.textContent = meta.size;
243
  }
244
 
245
- function sendConfig() {
246
- if (ws && ws.readyState === WebSocket.OPEN) {
247
- ws.send(JSON.stringify({
248
- type: "config",
249
- sampleRate: orig_sample_rate,
250
- model: modelSelect.value,
251
- precision: precisionSelect.value,
252
- hotwords: hotwordsList.value.split(/\r?\n/).filter(Boolean),
253
- hotwordsScore: parseFloat(boostScore.value)
254
- }));
255
- } else {
256
- console.warn("WebSocket not open yet. Cannot send config.");
257
- }
258
- }
259
-
260
  navigator.mediaDevices.getUserMedia({ audio: true }).then(stream => {
261
  const context = new AudioContext();
262
  orig_sample_rate = context.sampleRate;
@@ -270,20 +288,38 @@
270
 
271
  // Now that we know the sample rate, open the WS
272
  ws = new WebSocket(`wss://${location.host}/ws`);
273
- ws.onopen = () => sendConfig();
274
  ws.onerror = err => console.error("WebSocket error:", err);
275
  ws.onclose = () => console.log("WebSocket closed");
 
 
276
  ws.onmessage = e => {
277
  const msg = JSON.parse(e.data);
 
 
278
  if (msg.volume !== undefined) {
279
  vol.value = Math.min(msg.volume, 1.0);
280
  }
281
- if (msg.partial) {
282
- // replace content…
283
- transcript.textContent = msg.partial;
284
- // …then scroll to bottom
285
- transcript.scrollTop = transcript.scrollHeight;
 
 
286
  }
 
 
 
 
 
 
 
 
 
 
 
 
287
  };
288
 
289
  modelSelect.addEventListener("change", () => {
@@ -315,6 +351,71 @@
315
  ws.send(new Float32Array(input).buffer);
316
  };
317
  });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
  </script>
319
  </body>
320
  </html>
 
4
  <meta charset="UTF-8" />
5
  <title>🎤 Real-Time ASR Demo</title>
6
  <style>
7
+ /* Ensure the transcript preserves spacing and scrolls */
8
+ #transcript {
9
+ white-space: pre-wrap;
10
+ overflow-y: auto;
11
+ }
12
+
13
+ /* Finalized utterances in green, with a bit of right-margin */
14
+ #transcript .final {
15
+ color: green;
16
+ display: inline;
17
+ margin-right: 0.5em;
18
+ }
19
+
20
+ /* Interim utterance in red */
21
+ #transcript .interim {
22
+ color: red;
23
+ display: inline;
24
+ }
25
  body {
26
  font-family: "Segoe UI", sans-serif;
27
  background-color: #f5f6fa;
 
175
  </select>
176
  </div>
177
 
178
+ <div class="model-info" id="modelInfo">
179
+ Languages: <span id="modelLangs"></span> | Size: <span id="modelSize"></span> MB
180
+ </div>
181
+
182
  <div class="controls">
183
  <!-- Hotwords List Input -->
184
  <label for="hotwordsList">Hotwords:</label>
 
195
  <span id="hotwordStatus">Hotword Bias: Off</span>
196
  </div>
197
 
198
+ <div class="controls">
199
+ <!-- ⬇️ INSERT START: Endpoint Detection Controls ⬇️ -->
200
+ <label for="epRule1">Rule 1 (silence ≥ s):</label>
201
+ <input type="number" id="epRule1" step="0.1" value="2.4">
202
+
203
+ <label for="epRule2">Rule 2 (silence ≥ s):</label>
204
+ <input type="number" id="epRule2" step="0.1" value="1.2">
205
+
206
+ <label for="epRule3">Rule 3 (min utterance ms):</label>
207
+ <input type="number" id="epRule3" step="50" value="300">
208
+
209
+ <button id="applyEndpointConfig">Apply Endpoint Config</button>
210
+ <!-- ⬆️ INSERT END: Endpoint Detection Controls ⬆️ -->
211
  </div>
212
 
213
  <div class="mic-info">
 
275
  modelSize.textContent = meta.size;
276
  }
277
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  navigator.mediaDevices.getUserMedia({ audio: true }).then(stream => {
279
  const context = new AudioContext();
280
  orig_sample_rate = context.sampleRate;
 
288
 
289
  // Now that we know the sample rate, open the WS
290
  ws = new WebSocket(`wss://${location.host}/ws`);
291
+ ws.onopen = () => sendConfig();
292
  ws.onerror = err => console.error("WebSocket error:", err);
293
  ws.onclose = () => console.log("WebSocket closed");
294
+
295
+ // Unified handler for partial + final messages
296
  ws.onmessage = e => {
297
  const msg = JSON.parse(e.data);
298
+
299
+ // 1) update volume bar
300
  if (msg.volume !== undefined) {
301
  vol.value = Math.min(msg.volume, 1.0);
302
  }
303
+
304
+ // 2) distinguish “final” vs “partial”
305
+ if (msg.final !== undefined) {
306
+ finalUtterances.push(msg.final.trim());
307
+ currentInterim = "";
308
+ } else if (msg.partial !== undefined) {
309
+ currentInterim = msg.partial;
310
  }
311
+
312
+ // 3) rebuild the full, colored transcript
313
+ transcript.innerHTML =
314
+ finalUtterances
315
+ .map(u => `<span class="final">${u}</span>`)
316
+ .join("") /* margin in CSS handles spacing */
317
+ + (currentInterim
318
+ ? ` <span class="interim">${currentInterim}</span>`
319
+ : "");
320
+
321
+ // 4) auto-scroll to newest text
322
+ transcript.scrollTop = transcript.scrollHeight;
323
  };
324
 
325
  modelSelect.addEventListener("change", () => {
 
351
  ws.send(new Float32Array(input).buffer);
352
  };
353
  });
354
+
355
+ // 2) Declare state for final/interim rendering
356
+ const finalUtterances = [];
357
+ let currentInterim = "";
358
+
359
+ // 3) Grab your new inputs + button
360
+ const epRule1Input = document.getElementById("epRule1");
361
+ const epRule2Input = document.getElementById("epRule2");
362
+ const epRule3Input = document.getElementById("epRule3");
363
+ const applyEndpointBtn = document.getElementById("applyEndpointConfig");
364
+
365
+ // 4) Extend sendConfig() to include epRule1/2/3
366
+ function sendConfig() {
367
+ if (ws && ws.readyState === WebSocket.OPEN) {
368
+ ws.send(JSON.stringify({
369
+ type: "config",
370
+ sampleRate: orig_sample_rate,
371
+ model: modelSelect.value,
372
+ precision: precisionSelect.value,
373
+ hotwords: hotwordsList.value.split(/\r?\n/).filter(Boolean),
374
+ hotwordsScore: parseFloat(boostScore.value),
375
+
376
+ // ← new endpoint fields
377
+ epRule1: parseFloat(epRule1Input.value),
378
+ epRule2: parseFloat(epRule2Input.value),
379
+ epRule3: parseInt( epRule3Input.value, 10),
380
+ }));
381
+ }
382
+ }
383
+
384
+ // 5) Re-send config when user clicks “Apply Endpoint Config”
385
+ applyEndpointBtn.addEventListener("click", () => {
386
+ sendConfig();
387
+ });
388
+
389
+ // 6) Replace your existing ws.onmessage handler with this:
390
+ ws.onmessage = e => {
391
+ const msg = JSON.parse(e.data);
392
+
393
+ if (msg.volume !== undefined) {
394
+ vol.value = Math.min(msg.volume, 1.0);
395
+ }
396
+
397
+ if (msg.final !== undefined) {
398
+ // endpoint fired → lock in the final utterance
399
+ finalUtterances.push(msg.final.trim());
400
+ currentInterim = "";
401
+ } else if (msg.partial !== undefined) {
402
+ // update the rolling interim
403
+ currentInterim = msg.partial;
404
+ }
405
+
406
+ // rebuild the full transcript: green finals + red interim
407
+ transcript.innerHTML =
408
+ finalUtterances
409
+ .map(u => `<span class="final">${u}</span>`)
410
+ .join("") // no explicit space here, margin handles it
411
+ + (currentInterim
412
+ ? `<span class="interim">${currentInterim}</span>`
413
+ : "");
414
+
415
+ // always scroll to bottom
416
+ transcript.scrollTop = transcript.scrollHeight;
417
+ };
418
+
419
  </script>
420
  </body>
421
  </html>