qfuxa committed on
Commit
f19ad6b
·
1 Parent(s): d920423

Number of speakers is no longer limited to 10; a speaker label has been created for "being processed" (-1), and another one for "no speaker detected" (-2)

Browse files
Files changed (1) hide show
  1. src/diarization/diarization_online.py +48 -28
src/diarization/diarization_online.py CHANGED
@@ -5,6 +5,11 @@ from rx.subject import Subject
5
  import threading
6
  import numpy as np
7
  import asyncio
 
 
 
 
 
8
 
9
  class WebSocketAudioSource(AudioSource):
10
  """
@@ -44,37 +49,48 @@ def create_pipeline(SAMPLE_RATE):
44
  return inference, ws_source
45
 
46
 
47
- def init_diart(SAMPLE_RATE):
48
- inference, ws_source = create_pipeline(SAMPLE_RATE)
 
 
 
 
 
 
 
 
 
49
 
50
  def diar_hook(result):
51
  """
52
  Hook called each time Diart processes a chunk.
53
  result is (annotation, audio).
54
- We store the label of the last segment in 'current_speaker'.
55
  """
56
- global l_speakers
57
- l_speakers = []
58
  annotation, audio = result
59
- for speaker in annotation._labels:
60
- segments_beg = annotation._labels[speaker].segments_boundaries_[0]
61
- segments_end = annotation._labels[speaker].segments_boundaries_[-1]
62
- asyncio.create_task(
63
- l_speakers_queue.put({"speaker": speaker, "beg": segments_beg, "end": segments_end})
64
- )
 
 
 
 
 
 
 
65
 
66
- l_speakers_queue = asyncio.Queue()
67
  inference.attach_hooks(diar_hook)
68
-
69
- # Launch Diart in a background thread
70
  loop = asyncio.get_event_loop()
71
  diar_future = loop.run_in_executor(None, inference)
72
  return inference, l_speakers_queue, ws_source
73
 
74
-
75
- class DiartDiarization():
76
  def __init__(self, SAMPLE_RATE):
77
- self.inference, self.l_speakers_queue, self.ws_source = init_diart(SAMPLE_RATE)
 
78
  self.segment_speakers = []
79
 
80
  async def diarize(self, pcm_array):
@@ -82,20 +98,21 @@ class DiartDiarization():
82
  self.segment_speakers = []
83
  while not self.l_speakers_queue.empty():
84
  self.segment_speakers.append(await self.l_speakers_queue.get())
85
-
86
  def close(self):
87
  self.ws_source.close()
88
 
89
-
90
  def assign_speakers_to_chunks(self, chunks):
91
  """
92
- Go through each chunk and see which speaker(s) overlap
93
- that chunk's time range in the Diart annotation.
94
- Then store the speaker label(s) (or choose the most overlapping).
95
- This modifies `chunks` in-place or returns a new list with assigned speakers.
 
 
96
  """
97
- if not self.segment_speakers:
98
- return chunks
99
 
100
  for segment in self.segment_speakers:
101
  seg_beg = segment["beg"]
@@ -104,7 +121,10 @@ class DiartDiarization():
104
  for ch in chunks:
105
  if seg_end <= ch["beg"] or seg_beg >= ch["end"]:
106
  continue
107
- # We have overlap. Let's just pick the speaker (could be more precise in a more complex implementation)
108
- ch["speaker"] = speaker
 
 
 
109
 
110
- return chunks
 
5
  import threading
6
  import numpy as np
7
  import asyncio
8
+ import re
9
+
10
def extract_number(s):
    """Return the first run of decimal digits in *s* as an int, or None if absent."""
    digits = re.search(r'\d+', s)
    if digits is None:
        return None
    return int(digits.group(0))
13
 
14
  class WebSocketAudioSource(AudioSource):
15
  """
 
49
  return inference, ws_source
50
 
51
 
52
def init_diart(SAMPLE_RATE, diar_instance):
    """Build and launch a Diart streaming diarization pipeline.

    Parameters
    ----------
    SAMPLE_RATE : int
        Sample rate handed to the websocket audio source.
    diar_instance : DiartDiarization
        Owner object; its ``processed_time`` attribute is advanced by the
        hook below so callers can tell which time range has been diarized.

    Returns
    -------
    tuple
        ``(inference, l_speakers_queue, ws_source)`` — the running
        StreamingInference, the queue of detected speaker segments, and
        the audio source to push PCM into / close when done.
    """
    diar_pipeline = SpeakerDiarization()
    ws_source = WebSocketAudioSource(uri="websocket_source", sample_rate=SAMPLE_RATE)
    inference = StreamingInference(
        pipeline=diar_pipeline,
        source=ws_source,
        do_plot=False,
        show_progress=False,
    )

    l_speakers_queue = asyncio.Queue()
    # Capture the event loop NOW, on the caller's thread. The hook below runs
    # on the executor worker thread, where there is no running loop.
    loop = asyncio.get_event_loop()

    def diar_hook(result):
        """
        Hook called each time Diart processes a chunk.
        result is (annotation, audio).
        For each detected speaker segment, push its info to the queue and
        update processed_time.
        """
        annotation, audio = result
        if annotation._labels:
            for speaker in annotation._labels:
                segments_beg = annotation._labels[speaker].segments_boundaries_[0]
                segments_end = annotation._labels[speaker].segments_boundaries_[-1]
                if segments_end > diar_instance.processed_time:
                    diar_instance.processed_time = segments_end
                # BUGFIX: asyncio.create_task() requires a running event loop
                # in the *current* thread, but this hook fires on the executor
                # thread and would raise RuntimeError. Hand the enqueue over to
                # the loop thread instead; put_nowait is safe because the
                # queue is unbounded (never raises QueueFull).
                loop.call_soon_threadsafe(
                    l_speakers_queue.put_nowait,
                    {"speaker": speaker, "beg": segments_beg, "end": segments_end},
                )
        else:
            # No speaker in this chunk: still advance processed_time so the
            # chunk can later be marked as "no speaker detected".
            audio_duration = audio.extent.end
            if audio_duration > diar_instance.processed_time:
                diar_instance.processed_time = audio_duration

    inference.attach_hooks(diar_hook)
    # Launch Diart in a background thread. The future is deliberately not
    # awaited: streaming runs for the lifetime of ws_source.
    loop.run_in_executor(None, inference)
    return inference, l_speakers_queue, ws_source
89
 
90
+ class DiartDiarization:
 
91
    def __init__(self, SAMPLE_RATE):
        # Seconds of audio already diarized; advanced by the Diart hook.
        # NOTE: must be initialized *before* init_diart launches the pipeline,
        # because the hook reads it through the `self` reference passed below.
        self.processed_time = 0
        self.inference, self.l_speakers_queue, self.ws_source = init_diart(SAMPLE_RATE, self)
        # Speaker segments drained from the queue by the latest diarize() call.
        self.segment_speakers = []
95
 
96
  async def diarize(self, pcm_array):
 
98
  self.segment_speakers = []
99
  while not self.l_speakers_queue.empty():
100
  self.segment_speakers.append(await self.l_speakers_queue.get())
101
+
102
    def close(self):
        """Close the websocket audio source, which ends the Diart stream."""
        self.ws_source.close()
104
 
 
105
  def assign_speakers_to_chunks(self, chunks):
106
  """
107
+ For each chunk (a dict with keys "beg" and "end"), assign a speaker label.
108
+
109
+ - If a chunk overlaps with a detected speaker segment, assign that label.
110
+ - If the chunk's end time is within the processed time and no speaker was assigned,
111
+ mark it as "No speaker".
112
+ - If the chunk's time hasn't been fully processed yet, leave it (or mark as "Processing").
113
  """
114
+ for ch in chunks:
115
+ ch["speaker"] = ch.get("speaker", -1)
116
 
117
  for segment in self.segment_speakers:
118
  seg_beg = segment["beg"]
 
121
  for ch in chunks:
122
  if seg_end <= ch["beg"] or seg_beg >= ch["end"]:
123
  continue
124
+ ch["speaker"] = extract_number(speaker) + 1
125
+ if self.processed_time > 0:
126
+ for ch in chunks:
127
+ if ch["end"] <= self.processed_time and ch["speaker"] == -1:
128
+ ch["speaker"] = -2
129
 
130
+ return chunks