qfuxa committed 28f51dc (1 parent: b0d49ce)

DiartDiarization now uses SpeakerSegment

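The diff below imports SpeakerSegment from src/whisper_streaming/timed_objects but never shows its definition. A minimal sketch of what that class might look like, assuming a plain dataclass exposing exactly the speaker, start, and end fields the new code relies on (the real definition may carry more):

from dataclasses import dataclass

# Hypothetical reconstruction of SpeakerSegment; the actual class lives in
# src/whisper_streaming/timed_objects.py and may differ in detail.
@dataclass
class SpeakerSegment:
    speaker: str = ""   # diart speaker label, e.g. "speaker0"
    start: float = 0.0  # segment start time, in seconds
    end: float = 0.0    # segment end time, in seconds

Swapping the ad-hoc dict for such an object gives typo-proof attribute access (segment.end instead of segment["end"]), and the beg -> start rename presumably aligns the field names with the other timed objects in that module.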
src/diarization/diarization_online.py CHANGED
@@ -6,7 +6,7 @@ import numpy as np
 from diart import SpeakerDiarization
 from diart.inference import StreamingInference
 from diart.sources import AudioSource
-
+from src.whisper_streaming.timed_objects import SpeakerSegment
 
 def extract_number(s: str) -> int:
     m = re.search(r'\d+', s)
@@ -58,15 +58,15 @@ class DiartDiarization:
         annotation, audio = result
         if annotation._labels:
             for speaker, label in annotation._labels.items():
-                beg = label.segments_boundaries_[0]
+                start = label.segments_boundaries_[0]
                 end = label.segments_boundaries_[-1]
                 if end > self.processed_time:
                     self.processed_time = end
-                asyncio.create_task(self.speakers_queue.put({
-                    "speaker": speaker,
-                    "beg": beg,
-                    "end": end
-                }))
+                asyncio.create_task(self.speakers_queue.put(SpeakerSegment(
+                    speaker=speaker,
+                    start=start,
+                    end=end,
+                )))
         else:
             dur = audio.extent.end
             if dur > self.processed_time:
@@ -84,7 +84,7 @@ class DiartDiarization:
     def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> list:
         for token in tokens:
             for segment in self.segment_speakers:
-                if not (segment["end"] <= token.start or segment["beg"] >= token.end):
-                    token.speaker = extract_number(segment["speaker"]) + 1
+                if not (segment.end <= token.start or segment.start >= token.end):
+                    token.speaker = extract_number(segment.speaker) + 1
                 end_attributed_speaker = max(token.end, end_attributed_speaker)
         return end_attributed_speaker
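The condition rewritten at the end of assign_speakers_to_tokens is the usual interval-overlap test: a token is tagged with a speaker exactly when neither span ends before the other begins. A self-contained illustration of that predicate (overlaps is a hypothetical helper, and the timings are made up):

def overlaps(seg_start: float, seg_end: float, tok_start: float, tok_end: float) -> bool:
    # Mirrors the condition in the diff: the two intervals intersect
    # unless one ends at or before the point where the other starts.
    return not (seg_end <= tok_start or seg_start >= tok_end)

assert overlaps(0.5, 2.0, 1.5, 2.5)      # partial overlap: token gets the speaker
assert not overlaps(0.0, 1.0, 1.0, 2.0)  # touching endpoints only: no overlap

When a segment does overlap a token, extract_number(segment.speaker) + 1 turns diart's zero-based label (e.g. "speaker0") into a one-based speaker id on the token.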