qfuxa committed
Commit: a163c7b
Parent: a14d9a0

Diarization: use an rx observer instead of diart's attach_hooks method

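At the heart of the change, diart's attach_hooks callback is replaced by an RxPY observer registered through attach_observers. As a reference point for reading the diff, a minimal sketch of the observer contract (illustrative only, not part of the commit; assumes RxPY 3.x, where rx.core.Observer provides the on_next / on_error / on_completed callbacks the stream invokes):

# Minimal sketch of the rx observer contract (names are illustrative).
from rx.core import Observer

class PrintingObserver(Observer):
    def on_next(self, value):
        # Called once per emission; diart emits (annotation, audio) tuples,
        # which is what DiarizationObserver in the diff below unpacks.
        print("result:", value)

    def on_error(self, error):
        # Called if the stream terminates with an exception.
        print("stream error:", error)

    def on_completed(self):
        # Called when the stream finishes cleanly.
        print("stream completed")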
diarization/diarization_online.py CHANGED
@@ -2,16 +2,79 @@ import asyncio
 import re
 import threading
 import numpy as np
+import logging
+
 
 from diart import SpeakerDiarization
 from diart.inference import StreamingInference
 from diart.sources import AudioSource
 from timed_objects import SpeakerSegment
+from diart.sources import MicrophoneAudioSource
+from rx.core import Observer
+from typing import Tuple, Any, List
+from pyannote.core import Annotation
+
+logger = logging.getLogger(__name__)
 
 def extract_number(s: str) -> int:
     m = re.search(r'\d+', s)
     return int(m.group()) if m else None
 
+class DiarizationObserver(Observer):
+    """Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""
+
+    def __init__(self):
+        self.speaker_segments = []
+        self.processed_time = 0
+        self.segment_lock = threading.Lock()
+
+    def on_next(self, value: Tuple[Annotation, Any]):
+        annotation, audio = value
+
+        logger.debug("\n--- New Diarization Result ---")
+
+        duration = audio.extent.end - audio.extent.start
+        logger.debug(f"Audio segment: {audio.extent.start:.2f}s - {audio.extent.end:.2f}s (duration: {duration:.2f}s)")
+        logger.debug(f"Audio shape: {audio.data.shape}")
+
+        with self.segment_lock:
+            if audio.extent.end > self.processed_time:
+                self.processed_time = audio.extent.end
+            if annotation and len(annotation._labels) > 0:
+                logger.debug("\nSpeaker segments:")
+                for speaker, label in annotation._labels.items():
+                    for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]):
+                        logger.debug(f"  {speaker}: {start:.2f}s-{end:.2f}s")
+                        self.speaker_segments.append(SpeakerSegment(
+                            speaker=speaker,
+                            start=start,
+                            end=end
+                        ))
+            else:
+                logger.debug("\nNo speakers detected in this segment")
+
+    def get_segments(self) -> List[SpeakerSegment]:
+        """Get a copy of the current speaker segments."""
+        with self.segment_lock:
+            return self.speaker_segments.copy()
+
+    def clear_old_segments(self, older_than: float = 30.0):
+        """Clear segments older than the specified time."""
+        with self.segment_lock:
+            current_time = self.processed_time
+            self.speaker_segments = [
+                segment for segment in self.speaker_segments
+                if current_time - segment.end < older_than
+            ]
+
+    def on_error(self, error):
+        """Handle an error in the stream."""
+        logger.debug(f"Error in diarization stream: {error}")
+
+    def on_completed(self):
+        """Handle the completion of the stream."""
+        logger.debug("Diarization stream completed")
+
 
 class WebSocketAudioSource(AudioSource):
     """
@@ -34,57 +97,57 @@ class WebSocketAudioSource(AudioSource):
 
     def push_audio(self, chunk: np.ndarray):
         if not self._closed:
-            self.stream.on_next(np.expand_dims(chunk, axis=0))
+            new_audio = np.expand_dims(chunk, axis=0)
+            logger.debug("Add new chunk with shape: %s", new_audio.shape)
+            self.stream.on_next(new_audio)
 
 
 class DiartDiarization:
-    def __init__(self, sample_rate: int):
-        self.processed_time = 0
-        self.segment_speakers = []
-        self.speakers_queue = asyncio.Queue()
-        self.pipeline = SpeakerDiarization()
-        self.source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
+    def __init__(self, sample_rate: int, use_microphone: bool = False):
+        self.pipeline = SpeakerDiarization()
+        self.observer = DiarizationObserver()
+
+        if use_microphone:
+            self.source = MicrophoneAudioSource()
+            self.custom_source = None
+        else:
+            self.custom_source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
+            self.source = self.custom_source
+
         self.inference = StreamingInference(
             pipeline=self.pipeline,
             source=self.source,
             do_plot=False,
             show_progress=False,
         )
-        # Attach the hook function and start inference in the background.
-        self.inference.attach_hooks(self._diar_hook)
+        self.inference.attach_observers(self.observer)
         asyncio.get_event_loop().run_in_executor(None, self.inference)
 
-    def _diar_hook(self, result):
-        annotation, audio = result
-        if annotation._labels:
-            for speaker, label in annotation._labels.items():
-                start = label.segments_boundaries_[0]
-                end = label.segments_boundaries_[-1]
-                if end > self.processed_time:
-                    self.processed_time = end
-                asyncio.create_task(self.speakers_queue.put(SpeakerSegment(
-                    speaker=speaker,
-                    start=start,
-                    end=end,
-                )))
-        else:
-            dur = audio.extent.end
-            if dur > self.processed_time:
-                self.processed_time = dur
-
     async def diarize(self, pcm_array: np.ndarray):
-        self.source.push_audio(pcm_array)
-        self.segment_speakers.clear()
-        while not self.speakers_queue.empty():
-            self.segment_speakers.append(await self.speakers_queue.get())
+        """
+        Process audio data for diarization.
+        Only used when working with WebSocketAudioSource.
+        """
+        if self.custom_source:
+            self.custom_source.push_audio(pcm_array)
+        self.observer.clear_old_segments()
+        return self.observer.get_segments()
 
     def close(self):
-        self.source.close()
+        """Close the audio source."""
+        if self.custom_source:
+            self.custom_source.close()
 
-    def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> list:
+    def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float:
+        """
+        Assign speakers to tokens based on timing overlap with speaker segments.
+        Uses the segments collected by the observer.
+        """
+        segments = self.observer.get_segments()
+
         for token in tokens:
-            for segment in self.segment_speakers:
+            for segment in segments:
                 if not (segment.end <= token.start or segment.start >= token.end):
                     token.speaker = extract_number(segment.speaker) + 1
                     end_attributed_speaker = max(token.end, end_attributed_speaker)
-        return end_attributed_speaker
+        return end_attributed_speaker
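For orientation, a rough usage sketch of the reworked DiartDiarization (not part of the commit; assumes 16 kHz mono float32 PCM, a running asyncio event loop, and pyannote models already available to SpeakerDiarization; the silent chunk is a placeholder):

import asyncio
import numpy as np

async def main():
    diar = DiartDiarization(sample_rate=16000)   # default: WebSocket-backed source
    chunk = np.zeros(16000, dtype=np.float32)    # 1 s of silence as a stand-in
    segments = await diar.diarize(chunk)         # push audio, read current segments
    for seg in segments:
        print(f"{seg.speaker}: {seg.start:.2f}s-{seg.end:.2f}s")
    diar.close()

asyncio.run(main())

Unlike the old hook, which drained an internal queue, diarize() now returns a snapshot of the observer's segment list, with entries older than 30 s pruned by clear_old_segments().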
timed_objects.py CHANGED
@@ -8,6 +8,7 @@ class TimedText:
     text: Optional[str] = ''
     speaker: Optional[int] = -1
     probability: Optional[float] = None
+    is_dummy: Optional[bool] = False
 
 @dataclass
 class ASRToken(TimedText):
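A small illustration of the new flag (hypothetical values; start and end come from the TimedText base class): is_dummy defaults to False, so only explicitly marked placeholder tokens are treated as dummies.

# Hypothetical tokens; only the placeholder carries is_dummy=True.
placeholder = ASRToken(start=12.0, end=13.0, text=".", speaker=-1, is_dummy=True)
real = ASRToken(start=12.0, end=12.4, text="hello")
assert placeholder.is_dummy and not real.is_dummy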
whisper_fastapi_online_server.py CHANGED
@@ -49,7 +49,7 @@ parser.add_argument(
 parser.add_argument(
     "--confidence-validation",
     type=bool,
-    default=True,
+    default=False,
     help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
 )
 
@@ -110,9 +110,10 @@ class SharedState:
         current_time = time() - self.beg_loop
         dummy_token = ASRToken(
             start=current_time,
-            end=current_time + 0.5,
-            text="",
-            speaker=-1
+            end=current_time + 1,
+            text=".",
+            speaker=-1,
+            is_dummy=True
         )
         self.tokens.append(dummy_token)
 
@@ -275,14 +276,13 @@ async def results_formatter(shared_state, websocket):
         sep = state["sep"]
 
         # If diarization is enabled but no transcription, add dummy tokens periodically
-        if not tokens and not args.transcription and args.diarization:
+        if (not tokens or tokens[-1].is_dummy) and not args.transcription and args.diarization:
             await shared_state.add_dummy_token()
-            # Re-fetch tokens after adding dummy
+            await asyncio.sleep(0.5)
             state = await shared_state.get_current_state()
             tokens = state["tokens"]
-
         # Process tokens to create response
-        previous_speaker = -10
+        previous_speaker = -1
         lines = []
         last_end_diarized = 0
         undiarized_text = []
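One caveat on the --confidence-validation flag whose default flips above: argparse's type=bool does not parse the string "False" as false, since bool() of any non-empty string is True. A quick demonstration (illustrative, not from the commit):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--confidence-validation", type=bool, default=False)

print(parser.parse_args([]).confidence_validation)  # False (the new default)
# bool("False") is True, so passing the literal string still enables the flag:
args = parser.parse_args(["--confidence-validation", "False"])
print(args.confidence_validation)  # True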