qfuxa committed
Commit 2b4a348 · Parent: 7643321

use confidence scores returned by whisper to immediately validate tokens
README.md CHANGED
@@ -1,4 +1,4 @@
-# Real-time, fully local Speech-to-Text and speaker diarization using FastAPI WebSockets with a web interface
+# Real-time, Fully Local Speech-to-Text and Speaker Diarization
 
 This project is based on [Whisper Streaming](https://github.com/ufal/whisper_streaming) and lets you transcribe audio directly from your browser. Simply launch the local server and grant microphone access. Everything runs locally on your machine ✨
 
@@ -8,24 +8,23 @@ This project is based on [Whisper Streaming](https://github.com/ufal/whisper_str
 ### Differences from [Whisper Streaming](https://github.com/ufal/whisper_streaming)
 
-#### 🌐 **Web & API**
-- **Built-in Web UI** – No frontend setup required, just open your browser and start transcribing.
-- **FastAPI WebSocket Server** – Real-time speech-to-text processing with async FFmpeg streaming.
-- **JavaScript Client** – Ready-to-use MediaRecorder implementation for seamless client-side integration.
-
 #### ⚙️ **Core Improvements**
 - **Buffering Preview** – Displays unvalidated transcription segments for immediate feedback.
 - **Multi-User Support** – Handles multiple users simultaneously without conflicts.
 - **MLX Whisper Backend** – Optimized for Apple Silicon for faster local processing.
 - **Enhanced Sentence Segmentation** – Improved buffer trimming for better accuracy across languages.
-- **Extended Logging** – More detailed logs to improve debugging and monitoring.
+- **Confidence Validation** – Immediately validates high-confidence tokens for faster inference.
 
-#### 🎙️ **Advanced Features**
-- **Real-Time Diarization** – Identify different speakers in real time using [Diart](https://github.com/juanmc2005/diart).
+#### 🎙️ **Speaker Identification**
+- **Real-Time Diarization** – Identify different speakers in real time using [Diart](https://github.com/juanmc2005/diart).
+
+#### 🌐 **Web & API**
+- **Built-in Web UI** – Simple browser interface with no frontend setup required.
+- **FastAPI WebSocket Server** – Real-time speech-to-text processing with async FFmpeg streaming.
+- **JavaScript Client** – Ready-to-use MediaRecorder implementation for seamless client-side integration.
 
 #### 🚀 **Coming Soon**
 
-- **Faster Word Validation** – Accelerate real-time transcription by validating high-confidence words immediately upon first appearance for whisper backends that return word & segment probabilities
 - **Enhanced Diarization Performance** – Optimize speaker identification by implementing longer steps for Diart processing and leveraging language-specific segmentation patterns to improve speaker boundary detection.
 
@@ -87,12 +86,13 @@ This project is based on [Whisper Streaming](https://github.com/ufal/whisper_str
    python whisper_fastapi_online_server.py --host 0.0.0.0 --port 8000
    ```
 
+   All [Whisper Streaming](https://github.com/ufal/whisper_streaming) parameters are supported. Additional parameters:
+
    - `--host` and `--port` let you specify the server’s IP/port.
    - `--min-chunk-size` sets the minimum chunk size for audio processing. Make sure this value aligns with the chunk size selected in the frontend. If not aligned, the system will work but may unnecessarily over-process audio data.
-   - For a full list of configurable options, run `python whisper_fastapi_online_server.py -h`
-   - `--transcription`, default to True. Change to False if you want to run only diarization
-   - `--diarization`, default to False, let you choose whether or not you want to run diarization in parallel
-   - For other parameters, look at [whisper streaming](https://github.com/ufal/whisper_streaming) readme.
+   - `--transcription`: Enable/disable transcription (default: True).
+   - `--diarization`: Enable/disable speaker diarization (default: False).
+   - `--confidence-validation`: Use confidence scores for faster validation; transcription will be faster but punctuation might be less accurate (default: True).
 
 4. **Open the Provided HTML**:
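As a concrete illustration of the options documented above, a hypothetical launch that enables diarization alongside the default-on confidence validation (note the `type=bool` parsing caveat discussed under the server diff below):

```bash
python whisper_fastapi_online_server.py --host 0.0.0.0 --port 8000 --diarization True
```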
timed_objects.py CHANGED
@@ -7,12 +7,13 @@ class TimedText:
     end: Optional[float]
     text: Optional[str] = ''
     speaker: Optional[int] = -1
+    probability: Optional[float] = None
 
 @dataclass
 class ASRToken(TimedText):
     def with_offset(self, offset: float) -> "ASRToken":
         """Return a new token with the time offset added."""
-        return ASRToken(self.start + offset, self.end + offset, self.text)
+        return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability)
 
 @dataclass
 class Sentence(TimedText):
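For context, a minimal runnable sketch of the extended dataclasses and the bug this hunk fixes: `with_offset` previously rebuilt the token from `start`, `end`, and `text` only, silently dropping `speaker` and the new `probability`. The `start` field and the imports are assumptions inferred from the unchanged top of the file (`with_offset` references `self.start`):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class TimedText:
    start: Optional[float]               # assumed: defined just above this hunk
    end: Optional[float]
    text: Optional[str] = ''
    speaker: Optional[int] = -1
    probability: Optional[float] = None  # new field introduced by this commit

@dataclass
class ASRToken(TimedText):
    def with_offset(self, offset: float) -> "ASRToken":
        """Return a new token with the time offset added."""
        # The fix: forward speaker and probability so they survive re-offsetting.
        return ASRToken(self.start + offset, self.end + offset, self.text,
                        self.speaker, self.probability)

token = ASRToken(0.5, 0.9, "hello", probability=0.97)
assert token.with_offset(10.0).probability == 0.97  # previously dropped
```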
whisper_fastapi_online_server.py CHANGED
@@ -46,6 +46,13 @@ parser.add_argument(
     help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
 )
 
+parser.add_argument(
+    "--confidence-validation",
+    type=bool,
+    default=True,
+    help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
+)
+
 parser.add_argument(
     "--diarization",
     type=bool,
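One editorial caveat on this hunk (an observation, not a change in the commit): argparse's `type=bool` calls `bool()` on the raw string, and any non-empty string (including `"False"`) is truthy, so the option cannot actually be switched off from the command line. A sketch of one common alternative, `argparse.BooleanOptionalAction` (Python 3.9+):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--confidence-validation",
    action=argparse.BooleanOptionalAction,  # also generates --no-confidence-validation
    default=True,
    help="Accelerates validation of tokens using confidence scores.",
)

args = parser.parse_args(["--no-confidence-validation"])
print(args.confidence_validation)  # False
```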
whisper_streaming_custom/backends.py CHANGED
@@ -131,7 +131,7 @@ class FasterWhisperASR(ASRBase):
             if segment.no_speech_prob > 0.9:
                 continue
             for word in segment.words:
-                token = ASRToken(word.start, word.end, word.word)
+                token = ASRToken(word.start, word.end, word.word, probability=word.probability)
                 tokens.append(token)
         return tokens
 
@@ -210,7 +210,7 @@ class MLXWhisper(ASRBase):
             if segment.get("no_speech_prob", 0) > 0.9:
                 continue
             for word in segment.get("words", []):
-                token = ASRToken(word["start"], word["end"], word["word"])
+                token = ASRToken(word["start"], word["end"], word["word"], probability=word["probability"])
                 tokens.append(token)
         return tokens
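Both hunks simply forward per-word confidence values that the backends already compute. A hedged sketch of where they originate in faster-whisper (the audio path is a placeholder; `word_timestamps=True` is what populates `segment.words`):

```python
from faster_whisper import WhisperModel

model = WhisperModel("tiny", device="cpu", compute_type="int8")
segments, _info = model.transcribe("sample.wav", word_timestamps=True)  # placeholder path
for segment in segments:
    if segment.no_speech_prob > 0.9:  # same silence filter as the diff
        continue
    for word in segment.words:
        # Word objects carry start/end/word/probability, the fields used above.
        print(f"{word.word!r} [{word.start:.2f}-{word.end:.2f}] p={word.probability:.2f}")
```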
whisper_streaming_custom/online_asr.py CHANGED
@@ -16,7 +16,8 @@ class HypothesisBuffer:
     - buffer: the last hypothesis that is not yet committed
     - new: new tokens coming from the recognizer
     """
-    def __init__(self, logfile=sys.stderr):
+    def __init__(self, logfile=sys.stderr, confidence_validation=False):
+        self.confidence_validation = confidence_validation
         self.committed_in_buffer: List[ASRToken] = []
         self.buffer: List[ASRToken] = []
         self.new: List[ASRToken] = []
@@ -62,9 +63,15 @@ class HypothesisBuffer:
         committed: List[ASRToken] = []
         while self.new:
             current_new = self.new[0]
-            if not self.buffer:
+            if self.confidence_validation and current_new.probability and current_new.probability > 0.95:
+                committed.append(current_new)
+                self.last_committed_word = current_new.text
+                self.last_committed_time = current_new.end
+                self.new.pop(0)
+                self.buffer.pop(0) if self.buffer else None
+            elif not self.buffer:
                 break
-            if current_new.text == self.buffer[0].text:
+            elif current_new.text == self.buffer[0].text:
                 committed.append(current_new)
                 self.last_committed_word = current_new.text
                 self.last_committed_time = current_new.end
@@ -102,6 +109,7 @@ class OnlineASRProcessor:
         asr,
         tokenize_method: Optional[callable] = None,
         buffer_trimming: Tuple[str, float] = ("segment", 15),
+        confidence_validation=False,
         logfile=sys.stderr,
     ):
         """
@@ -114,7 +122,7 @@ class OnlineASRProcessor:
         self.asr = asr
         self.tokenize = tokenize_method
         self.logfile = logfile
-
+        self.confidence_validation = confidence_validation
         self.init()
 
         self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
@@ -131,7 +139,7 @@ class OnlineASRProcessor:
     def init(self, offset: Optional[float] = None):
         """Initialize or reset the processing buffers."""
         self.audio_buffer = np.array([], dtype=np.float32)
-        self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
+        self.transcript_buffer = HypothesisBuffer(logfile=self.logfile, confidence_validation=self.confidence_validation)
         self.buffer_time_offset = offset if offset is not None else 0.0
         self.transcript_buffer.last_committed_time = self.buffer_time_offset
         self.committed: List[ASRToken] = []
@@ -323,13 +331,14 @@ class OnlineASRProcessor:
     ) -> Transcript:
         sep = sep if sep is not None else self.asr.sep
         text = sep.join(token.text for token in tokens)
+        probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
         if tokens:
             start = offset + tokens[0].start
             end = offset + tokens[-1].end
         else:
             start = None
             end = None
-        return Transcript(start, end, text)
+        return Transcript(start, end, text, probability=probability)
 
 
 class VACOnlineASRProcessor:
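The core of the commit is the new fast path in `HypothesisBuffer.commit`. The baseline LocalAgreement policy commits a token only once two consecutive hypotheses agree on it, which delays output by at least one update cycle; with confidence validation enabled, a token whose probability exceeds 0.95 is committed on first sight. A minimal self-contained sketch of the resulting policy (simplified; `Tok` stands in for `ASRToken`):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Tok:
    text: str
    probability: Optional[float] = None

def commit(new: List[Tok], buffer: List[Tok], threshold: float = 0.95) -> List[Tok]:
    """Validate tokens from `new`, mutating both lists like the original method."""
    committed = []
    while new:
        tok = new[0]
        if tok.probability and tok.probability > threshold:
            # Fast path added by this commit: commit high-confidence tokens
            # immediately, without waiting for a second agreeing hypothesis.
            committed.append(new.pop(0))
            if buffer:
                buffer.pop(0)
        elif not buffer:
            break
        elif tok.text == buffer[0].text:
            # Baseline LocalAgreement rule: commit on two-run agreement.
            committed.append(new.pop(0))
            buffer.pop(0)
        else:
            break
    return committed

print(commit([Tok("hello", 0.99), Tok("world", 0.42)], [Tok("hello", 0.99)]))
# [Tok(text='hello', probability=0.99)]  ('world' waits for agreement)
```

Two minor editorial notes on the hunks above: `self.buffer.pop(0) if self.buffer else None` is a conditional expression used as a statement and works, but a plain `if self.buffer: self.buffer.pop(0)` would be more idiomatic; and the averaged `Transcript` probability sums only tokens that carry a probability yet divides by `len(tokens)`, so transcripts containing tokens without probabilities are biased low.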
whisper_streaming_custom/whisper_online.py CHANGED
@@ -77,7 +77,7 @@ def add_shared_args(parser):
     parser.add_argument(
         "--model",
         type=str,
-        default="tiny.en",
+        default="large-v3-turbo",
         choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
             ","
         ),
@@ -207,6 +207,7 @@ def online_factory(args, asr, tokenizer, logfile=sys.stderr):
             tokenizer,
             logfile=logfile,
             buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
+            confidence_validation=args.confidence_validation,
         )
     else:
         online = OnlineASRProcessor(
@@ -214,6 +215,7 @@ def online_factory(args, asr, tokenizer, logfile=sys.stderr):
             tokenizer,
             logfile=logfile,
             buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
+            confidence_validation=args.confidence_validation,
         )
     return online