The number of speakers is no longer limited to 10; a speaker label has been created for "being processed" (-1), and another for "no speaker detected" (-2).
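In other words, the "speaker" field on a chunk is now an integer: -1 while Diart has not yet processed that time span, -2 once the span has been processed but no speaker was detected, and a 1-based speaker number otherwise. A minimal sketch of how a consumer might render these values; the speaker_label helper is hypothetical, not part of this commit:

def speaker_label(speaker: int) -> str:
    # -1: Diart has not finished processing this time span yet.
    if speaker == -1:
        return "Processing..."
    # -2: the span was processed, but no speaker was detected.
    if speaker == -2:
        return "No speaker"
    # Otherwise a 1-based speaker number (see extract_number in the diff).
    return f"Speaker {speaker}"

assert speaker_label(-2) == "No speaker"
assert speaker_label(3) == "Speaker 3"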
src/diarization/diarization_online.py
CHANGED
@@ -5,6 +5,11 @@ from rx.subject import Subject
 import threading
 import numpy as np
 import asyncio
+import re
+
+def extract_number(s):
+    match = re.search(r'\d+', s)
+    return int(match.group()) if match else None
 
 class WebSocketAudioSource(AudioSource):
     """
@@ -44,37 +49,48 @@ def create_pipeline(SAMPLE_RATE):
     return inference, ws_source
 
 
-def init_diart(SAMPLE_RATE):
-
+def init_diart(SAMPLE_RATE, diar_instance):
+    diar_pipeline = SpeakerDiarization()
+    ws_source = WebSocketAudioSource(uri="websocket_source", sample_rate=SAMPLE_RATE)
+    inference = StreamingInference(
+        pipeline=diar_pipeline,
+        source=ws_source,
+        do_plot=False,
+        show_progress=False,
+    )
+
+    l_speakers_queue = asyncio.Queue()
 
     def diar_hook(result):
         """
        Hook called each time Diart processes a chunk.
        result is (annotation, audio).
-
+        For each detected speaker segment, push its info to the queue and update processed_time.
         """
-        global l_speakers
-        l_speakers = []
        annotation, audio = result
-
-
-
-
-
-
+        if annotation._labels:
+            for speaker in annotation._labels:
+                segments_beg = annotation._labels[speaker].segments_boundaries_[0]
+                segments_end = annotation._labels[speaker].segments_boundaries_[-1]
+                if segments_end > diar_instance.processed_time:
+                    diar_instance.processed_time = segments_end
+                asyncio.create_task(
+                    l_speakers_queue.put({"speaker": speaker, "beg": segments_beg, "end": segments_end})
+                )
+        else:
+            audio_duration = audio.extent.end
+            if audio_duration > diar_instance.processed_time:
+                diar_instance.processed_time = audio_duration
 
-    l_speakers_queue = asyncio.Queue()
     inference.attach_hooks(diar_hook)
-
-    # Launch Diart in a background thread
     loop = asyncio.get_event_loop()
     diar_future = loop.run_in_executor(None, inference)
     return inference, l_speakers_queue, ws_source
 
-
-class DiartDiarization():
+class DiartDiarization:
     def __init__(self, SAMPLE_RATE):
-        self.
+        self.processed_time = 0
+        self.inference, self.l_speakers_queue, self.ws_source = init_diart(SAMPLE_RATE, self)
        self.segment_speakers = []
 
    async def diarize(self, pcm_array):
@@ -82,20 +98,21 @@ class DiartDiarization():
        self.segment_speakers = []
        while not self.l_speakers_queue.empty():
            self.segment_speakers.append(await self.l_speakers_queue.get())
-
+
    def close(self):
        self.ws_source.close()
 
-
    def assign_speakers_to_chunks(self, chunks):
        """
-
-
-
-
+        For each chunk (a dict with keys "beg" and "end"), assign a speaker label.
+
+        - If a chunk overlaps with a detected speaker segment, assign that label.
+        - If the chunk's end time is within the processed time and no speaker was assigned,
+          mark it as "No speaker".
+        - If the chunk's time hasn't been fully processed yet, leave it (or mark as "Processing").
        """
-
-
+        for ch in chunks:
+            ch["speaker"] = ch.get("speaker", -1)
 
        for segment in self.segment_speakers:
            seg_beg = segment["beg"]
@@ -104,7 +121,10 @@ class DiartDiarization():
            for ch in chunks:
                if seg_end <= ch["beg"] or seg_beg >= ch["end"]:
                    continue
-
-
+                ch["speaker"] = extract_number(speaker) + 1
+        if self.processed_time > 0:
+            for ch in chunks:
+                if ch["end"] <= self.processed_time and ch["speaker"] == -1:
+                    ch["speaker"] = -2
 
-        return chunks
+        return chunks
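For context, a minimal usage sketch, not part of the commit: it assumes a 16 kHz stream, that diarize() forwards pcm_array to the WebSocketAudioSource as in the rest of this file, and chunk timestamps chosen purely for illustration.

import asyncio
import numpy as np

async def demo():
    diar = DiartDiarization(SAMPLE_RATE=16000)

    # Push some audio through; diarize() drains the queue that diar_hook fills.
    pcm_array = np.zeros(16000, dtype=np.float32)  # one second of silence, illustrative
    await diar.diarize(pcm_array)

    # Chunks as produced upstream: dicts with "beg"/"end" times in seconds.
    chunks = [{"beg": 0.0, "end": 0.5}, {"beg": 0.5, "end": 1.0}]
    for ch in diar.assign_speakers_to_chunks(chunks):
        # ch["speaker"] is -1 (still being processed), -2 (no speaker
        # detected), or a 1-based speaker number.
        print(ch["beg"], ch["end"], ch["speaker"])

    diar.close()

asyncio.run(demo())

Chunks left at -1 simply have not been reached by the pipeline yet and can be re-labelled on a later call.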