DiartDiarization now uses SpeakerSegment
src/diarization/diarization_online.py (changed)
@@ -6,7 +6,7 @@ import numpy as np
 from diart import SpeakerDiarization
 from diart.inference import StreamingInference
 from diart.sources import AudioSource
-
+from src.whisper_streaming.timed_objects import SpeakerSegment
 
 def extract_number(s: str) -> int:
     m = re.search(r'\d+', s)
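The new import pulls in SpeakerSegment from src.whisper_streaming.timed_objects; its definition is not part of this diff. Judging from the keyword arguments used in the second hunk (speaker, start, end), a minimal sketch, assuming a plain dataclass, might look like this:

    # Hypothetical sketch only: the real class lives in
    # src/whisper_streaming/timed_objects.py and is not shown in this commit.
    from dataclasses import dataclass

    @dataclass
    class SpeakerSegment:
        speaker: str  # diart label, e.g. "speaker0"
        start: float  # segment start time in seconds
        end: float    # segment end time in seconds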
@@ -58,15 +58,15 @@ class DiartDiarization:
         annotation, audio = result
         if annotation._labels:
             for speaker, label in annotation._labels.items():
-
+                start = label.segments_boundaries_[0]
                 end = label.segments_boundaries_[-1]
                 if end > self.processed_time:
                     self.processed_time = end
-                asyncio.create_task(self.speakers_queue.put(
-
-
-
+                asyncio.create_task(self.speakers_queue.put(SpeakerSegment(
+                    speaker=speaker,
+                    start=start,
+                    end=end,
+                )))
         else:
             dur = audio.extent.end
             if dur > self.processed_time:
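The hook now publishes one typed SpeakerSegment per diart label, and asyncio.create_task schedules the queue put so the callback can return without awaiting it. A minimal consumer sketch; the function name is assumed, while speakers_queue and segment_speakers mirror the attributes used in the diff:

    import asyncio

    async def collect_segments(speakers_queue: asyncio.Queue, segment_speakers: list):
        # Hypothetical consumer: drain segments published by the hook into
        # the list that assign_speakers_to_tokens() later scans.
        while True:
            segment = await speakers_queue.get()  # a SpeakerSegment
            segment_speakers.append(segment)
            speakers_queue.task_done()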
@@ -84,7 +84,7 @@ class DiartDiarization:
     def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> list:
         for token in tokens:
             for segment in self.segment_speakers:
-                if not (segment
-                    token.speaker = extract_number(segment
+                if not (segment.end <= token.start or segment.start >= token.end):
+                    token.speaker = extract_number(segment.speaker) + 1
                     end_attributed_speaker = max(token.end, end_attributed_speaker)
         return end_attributed_speaker
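The rewritten test on line 87 is the standard interval-overlap check: a segment and a token overlap exactly when neither span ends before the other starts. A self-contained illustration (the helper name is mine, not from the repo):

    def overlaps(a_start: float, a_end: float, b_start: float, b_end: float) -> bool:
        # Negation of "a ends at or before b starts, or a starts at or after b ends".
        return not (a_end <= b_start or a_start >= b_end)

    assert overlaps(0.0, 1.0, 0.5, 1.5)      # partial overlap
    assert not overlaps(0.0, 1.0, 1.0, 2.0)  # a shared endpoint alone is not overlap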
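Line 88 switches to attribute access (segment.speaker) and applies a + 1 offset, so diart's zero-based labels map to one-based speaker IDs on the token. Assuming extract_number returns the first integer found in the label (only its re.search line is visible above), the mapping behaves like this:

    import re

    def extract_number(s: str) -> int:
        # Body completed by assumption: only the re.search line appears in the diff.
        m = re.search(r'\d+', s)
        return int(m.group()) if m else -1

    assert extract_number("speaker0") + 1 == 1  # first diart speaker becomes ID 1
    assert extract_number("speaker7") + 1 == 8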