Diarization now works at word level, not chunk level!
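Speaker labels are now attached to the individual ASRToken objects returned by the ASR processor instead of to whole transcription chunks: each token is matched against the diarization segments it overlaps in time and carries its own speaker field. The --diarization flag also now defaults to True.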
src/diarization/diarization_online.py
CHANGED
@@ -81,11 +81,10 @@ class DiartDiarization:
     def close(self):
         self.source.close()
 
-    def …
-        …
-        for chunk in chunks:
+    def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> list:
+        for token in tokens:
             for segment in self.segment_speakers:
-                if not (segment["end"] <= …
-                    …
-                end_attributed_speaker = …
+                if not (segment["end"] <= token.start or segment["beg"] >= token.end):
+                    token.speaker = extract_number(segment["speaker"]) + 1
+                    end_attributed_speaker = max(token.end, end_attributed_speaker)
         return end_attributed_speaker
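To make the token-level matching concrete, here is a minimal, self-contained sketch of the same overlap test on stand-in data. FakeToken and the "speaker_N" label parsing are invented for illustration; the real method uses ASRToken, self.segment_speakers, and the extract_number helper.

```python
from dataclasses import dataclass

# Stand-in for ASRToken (the real class lives in src/whisper_streaming/timed_objects.py).
@dataclass
class FakeToken:
    start: float
    end: float
    text: str = ""
    speaker: int = -1

# Stand-ins for the entries in self.segment_speakers; the "speaker_N" format
# and the int-parsing below only mimic extract_number and are made up.
segments = [
    {"beg": 0.0, "end": 1.2, "speaker": "speaker_0"},
    {"beg": 1.2, "end": 2.5, "speaker": "speaker_1"},
]
tokens = [FakeToken(0.1, 0.5, "hello"), FakeToken(1.4, 1.8, "there")]

end_attributed_speaker = 0
for token in tokens:
    for segment in segments:
        # Same overlap test as assign_speakers_to_tokens: a token inherits the
        # speaker of any diarization segment it overlaps in time.
        if not (segment["end"] <= token.start or segment["beg"] >= token.end):
            token.speaker = int(segment["speaker"].rsplit("_", 1)[-1]) + 1
            end_attributed_speaker = max(token.end, end_attributed_speaker)

print([(t.text, t.speaker) for t in tokens])  # [('hello', 1), ('there', 2)]
print(end_attributed_speaker)                 # 1.8
```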
src/whisper_streaming/online_asr.py
CHANGED
@@ -202,7 +202,7 @@ class OnlineASRProcessor:
         logger.debug(
             f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
         )
-        return
+        return committed_tokens
 
     def chunk_completed_sentence(self):
         """
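The bare return now hands the tokens committed during that iteration back to the caller, which is what the server loop below relies on. A self-contained mock of the calling pattern; MockASRProcessor is made up for illustration, the real class is OnlineASRProcessor and the real call is online.process_iter().

```python
from dataclasses import dataclass

@dataclass
class Token:                      # stand-in for ASRToken
    start: float
    end: float
    text: str

class MockASRProcessor:
    sep = " "
    def __init__(self):
        # Pretend each call commits zero or more tokens.
        self._iterations = [[Token(0.0, 0.4, "hello")], [], [Token(0.5, 0.9, "world")]]
    def process_iter(self):
        # Returns the tokens committed during this iteration (possibly none).
        return self._iterations.pop(0) if self._iterations else []

online = MockASRProcessor()
tokens = []
for _ in range(3):
    new_tokens = online.process_iter()
    tokens.extend(new_tokens)     # caller accumulates the committed tokens

print([t.text for t in tokens])   # ['hello', 'world']
```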
src/whisper_streaming/timed_objects.py
CHANGED
@@ -5,7 +5,8 @@ from typing import Optional
 class TimedText:
     start: Optional[float]
     end: Optional[float]
-    text: str
+    text: Optional[str] = ''
+    speaker: Optional[int] = -1
 
 @dataclass
 class ASRToken(TimedText):
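The new defaults are what let the server build placeholder tokens with only timing information (see the else branch in the server diff below). A quick sketch, assuming the dataclasses behave as declared in this file:

```python
from src.whisper_streaming.timed_objects import ASRToken

# Placeholder token with only timestamps, as created when transcription is
# disabled; text defaults to '' and speaker to -1.
t = ASRToken(start=12.0, end=12.5)
print(repr(t.text), t.speaker)    # '' -1

# A fully populated token once ASR and diarization have both run.
t2 = ASRToken(start=12.0, end=12.5, text="hello", speaker=2)
```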
whisper_fastapi_online_server.py
CHANGED
@@ -11,6 +11,7 @@ from fastapi.responses import HTMLResponse
 from fastapi.middleware.cors import CORSMiddleware
 
 from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
+from src.whisper_streaming.timed_objects import ASRToken
 
 import math
 import logging
@@ -47,7 +48,7 @@ parser.add_argument(
 parser.add_argument(
     "--diarization",
     type=bool,
-    default=…
+    default=True,
     help="Whether to enable speaker diarization.",
 )
 
@@ -157,7 +158,9 @@ async def websocket_endpoint(websocket: WebSocket):
     full_transcription = ""
     beg = time()
     beg_loop = time()
-    …
+    tokens = []
+    end_attributed_speaker = 0
+    sep = online.asr.sep
 
     while True:
         try:
@@ -177,7 +180,6 @@ async def websocket_endpoint(websocket: WebSocket):
                 logger.warning("FFmpeg read timeout. Restarting...")
                 await restart_ffmpeg()
                 full_transcription = ""
-                chunk_history = []
                 beg = time()
                 continue # Skip processing and read from new process
 
@@ -202,63 +204,53 @@ async def websocket_endpoint(websocket: WebSocket):
             if args.transcription:
                 logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
                 online.insert_audio_chunk(pcm_array)
-                …
-                …
-                …
-                    "beg": transcription.start,
-                    "end": transcription.end,
-                    "text": transcription.text,
-                    "speaker": -1
-                })
-                full_transcription += transcription.text if transcription else ""
+                new_tokens = online.process_iter()
+                tokens.extend(new_tokens)
+                full_transcription += sep.join([t.text for t in new_tokens])
                 buffer = online.get_buffer()
                 if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed
                     buffer = ""
             else:
-                …
-                …
-                …
-                …
-                …
-                })
-                sleep(1)
+                tokens.append(
+                    ASRToken(
+                        start = time() - beg_loop,
+                        end = time() - beg_loop + 0.5))
+                sleep(0.5)
                 buffer = ''
 
             if args.diarization:
                 await diarization.diarize(pcm_array)
-                end_attributed_speaker = diarization.…
-                …
+                end_attributed_speaker = diarization.assign_speakers_to_tokens(end_attributed_speaker, tokens)
 
-            …
+            previous_speaker = -10
             lines = []
             last_end_diarized = 0
-            …
-            …
-                speaker = ch.get("speaker")
+            for token in tokens:
+                speaker = token.speaker
                 if args.diarization:
                     if speaker == -1 or speaker == 0:
-                        if …
+                        if token.end < end_attributed_speaker:
                             speaker = previous_speaker
                         else:
                             speaker = 0
                     else:
-                        last_end_diarized = max(…
+                        last_end_diarized = max(token.end, last_end_diarized)
 
-                if speaker != …
+                if speaker != previous_speaker:
                     lines.append(
                         {
                             "speaker": speaker,
-                            "text": …
-                            "beg": format_time(…
-                            "end": format_time(…
-                            "diff": round(…
+                            "text": token.text,
+                            "beg": format_time(token.start),
+                            "end": format_time(token.end),
+                            "diff": round(token.end - last_end_diarized, 2)
                         }
                     )
-                    …
+                    previous_speaker = speaker
                 else:
-                    lines[-1]["text"] += …
-                    lines[-1]["end"] = format_time(…
-                    lines[-1]["diff"] = round(…
+                    lines[-1]["text"] += sep + token.text
+                    lines[-1]["end"] = format_time(token.end)
+                    lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
 
             response = {"lines": lines, "buffer": buffer}
             await websocket.send_json(response)
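For anyone adapting a frontend, the message pushed over the websocket on each loop iteration now has the shape sketched below. The keys come from the diff above; the concrete values and the exact string produced by format_time are illustrative. In the grouping logic above, tokens that diarization has not reached yet (token.end beyond end_attributed_speaker) are shown with speaker 0, while unattributed tokens inside the already-attributed range inherit the previous speaker.

```python
# Illustrative shape of the JSON sent to the client; values and timestamp
# strings are made up, since format_time's exact output is not shown here.
response = {
    "lines": [
        {"speaker": 1, "text": "hello there", "beg": "0:00:00", "end": "0:00:01", "diff": 0.0},
        {"speaker": 2, "text": "hi how are you", "beg": "0:00:01", "end": "0:00:03", "diff": 0.21},
    ],
    "buffer": "I was just",  # uncommitted text still being refined by the ASR
}
```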