File size: 9,902 Bytes
5806e12 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 |
import os
from typing import Callable, Text, Union
from typing import Optional
import numpy as np
import torch
from pyannote.audio import Model
from pyannote.audio.core.io import AudioFile
from pyannote.audio.pipelines import VoiceActivityDetection
from pyannote.audio.pipelines.utils import PipelineModel
from pyannote.core import Annotation, SlidingWindowFeature
from pyannote.core import Segment
from ..diarize import Segment as SegmentX
from .vad import Vad
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
    """Load the pyannote segmentation checkpoint and wrap it in a VAD pipeline.

    Parameters
    ----------
    device : str or torch.device
        Device for the pipeline; passed through ``torch.device``.
    vad_onset : float, optional
        Value stored under the ``"onset"`` pipeline hyperparameter.
    vad_offset : float, optional
        Value stored under the ``"offset"`` pipeline hyperparameter.
    use_auth_token : str, optional
        Forwarded to ``Model.from_pretrained``.
    model_fp : str, optional
        Path to the checkpoint. When None, ``assets/pytorch_model.bin`` located
        two directories above this module file is used.

    Returns
    -------
    VoiceActivitySegmentation
        Pipeline instantiated with the hyperparameters above.

    Raises
    ------
    FileNotFoundError
        If the resolved checkpoint path does not exist.
    RuntimeError
        If the resolved checkpoint path exists but is not a regular file.
    """
    # Kept from the original: make sure the torch hub cache directory exists.
    os.makedirs(torch.hub._get_torch_home(), exist_ok=True)

    if model_fp is None:
        # Dynamically resolve the default checkpoint relative to this module.
        main_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        model_fp = os.path.join(main_dir, "assets", "pytorch_model.bin")
    # Always work with an absolute path, whether defaulted or user supplied.
    model_fp = os.path.abspath(model_fp)

    if not os.path.exists(model_fp):
        raise FileNotFoundError(f"Model file not found at {model_fp}")
    if not os.path.isfile(model_fp):
        raise RuntimeError(f"{model_fp} exists and is not a regular file")

    # The original also read the whole file into an unused variable (leaking
    # the handle); from_pretrained loads the checkpoint itself.
    vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
    hyperparameters = {
        "onset": vad_onset,
        "offset": vad_offset,
        "min_duration_on": 0.1,
        "min_duration_off": 0.1,
    }
    vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device))
    vad_pipeline.instantiate(hyperparameters)
    return vad_pipeline
class Binarize:
    """Binarize detection scores using hysteresis thresholding, with min-cut operation
    to ensure no segments are longer than max_duration.

    Parameters
    ----------
    onset : float, optional
        Onset threshold. Defaults to 0.5.
    offset : float, optional
        Offset threshold. Defaults to `onset` when left as None (an explicit
        0.0 is kept as 0.0).
    min_duration_on : float, optional
        Remove active regions shorter than that many seconds. Defaults to 0s.
    min_duration_off : float, optional
        Fill inactive regions shorter than that many seconds. Defaults to 0s.
    pad_onset : float, optional
        Extend active regions by moving their start time by that many seconds.
        Defaults to 0s.
    pad_offset : float, optional
        Extend active regions by moving their end time by that many seconds.
        Defaults to 0s.
    max_duration: float
        The maximum length of an active segment, divides segment at timestamp with lowest score.

    Reference
    ---------
    Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
    RNN-based Voice Activity Detection", InterSpeech 2015.

    Modified by Max Bain to include WhisperX's min-cut operation
    https://arxiv.org/abs/2303.00747

    Pyannote-audio
    """

    def __init__(
        self,
        onset: float = 0.5,
        offset: Optional[float] = None,
        min_duration_on: float = 0.0,
        min_duration_off: float = 0.0,
        pad_onset: float = 0.0,
        pad_offset: float = 0.0,
        max_duration: float = float('inf')
    ):
        super().__init__()
        self.onset = onset
        # ``offset or onset`` would silently replace an explicit 0.0 with onset.
        self.offset = onset if offset is None else offset
        self.pad_onset = pad_onset
        self.pad_offset = pad_offset
        self.min_duration_on = min_duration_on
        self.min_duration_off = min_duration_off
        self.max_duration = max_duration

    def __call__(self, scores: "SlidingWindowFeature") -> "Annotation":
        """Binarize detection scores

        Parameters
        ----------
        scores : SlidingWindowFeature
            Detection scores (frames x classes).

        Returns
        -------
        active : Annotation
            Binarized scores.

        Raises
        ------
        NotImplementedError
            If padding or min_duration_off is requested together with a
            finite max_duration.
        """
        num_frames, _num_classes = scores.data.shape
        frames = scores.sliding_window
        timestamps = [frames[i].middle for i in range(num_frames)]

        # annotation meant to store 'active' regions
        active = Annotation()

        # Each class scan seeds its state from the first frame; with zero
        # frames there is nothing to binarize (previously an IndexError).
        if num_frames > 0:
            for k, k_scores in enumerate(scores.data.T):
                label = k if scores.labels is None else scores.labels[k]
                self._binarize_track(active, timestamps, k_scores, k, label)

        # because of padding, some active regions might be overlapping: merge them.
        # also: fill same speaker gaps shorter than min_duration_off
        if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
            if self.max_duration < float("inf"):
                raise NotImplementedError("This would break current max_duration param")
            active = active.support(collar=self.min_duration_off)

        # remove tracks shorter than min_duration_on
        if self.min_duration_on > 0:
            # Materialize first: tracks are deleted while walking them.
            for segment, track in list(active.itertracks()):
                if segment.duration < self.min_duration_on:
                    del active[segment, track]

        return active

    def _binarize_track(self, active, timestamps, k_scores, k, label):
        """Add the hysteresis-thresholded regions of one class track to ``active``."""
        # initial state
        start = timestamps[0]
        is_active = k_scores[0] > self.onset
        curr_scores = [k_scores[0]]
        curr_timestamps = [start]
        t = start  # keeps the final-region end defined for single-frame input

        for t, y in zip(timestamps[1:], k_scores[1:]):
            # currently active
            if is_active:
                curr_duration = t - start
                if curr_duration > self.max_duration:
                    # Region grew past max_duration: cut it at the lowest-scoring
                    # frame within the second half of the frames collected so far.
                    search_after = len(curr_scores) // 2
                    min_score_div_idx = search_after + np.argmin(curr_scores[search_after:])
                    min_score_t = curr_timestamps[min_score_div_idx]
                    region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
                    active[region, k] = label
                    start = curr_timestamps[min_score_div_idx]
                    curr_scores = curr_scores[min_score_div_idx + 1:]
                    curr_timestamps = curr_timestamps[min_score_div_idx + 1:]
                # switching from active to inactive
                elif y < self.offset:
                    region = Segment(start - self.pad_onset, t + self.pad_offset)
                    active[region, k] = label
                    start = t
                    is_active = False
                    curr_scores = []
                    curr_timestamps = []
                curr_scores.append(y)
                curr_timestamps.append(t)
            # currently inactive: switching from inactive to active
            elif y > self.onset:
                start = t
                is_active = True

        # if active at the end, add final region
        if is_active:
            region = Segment(start - self.pad_onset, t + self.pad_offset)
            active[region, k] = label
class VoiceActivitySegmentation(VoiceActivityDetection):
    """``VoiceActivityDetection`` variant whose ``apply`` stops after segmentation.

    Rather than binarizing, ``apply`` returns whatever ``self._segmentation``
    produces, so callers can apply their own thresholding.
    """

    def __init__(
        self,
        segmentation: PipelineModel = "pyannote/segmentation",
        fscore: bool = False,
        use_auth_token: Union[Text, None] = None,
        **inference_kwargs,
    ):
        super().__init__(
            segmentation=segmentation,
            fscore=fscore,
            use_auth_token=use_auth_token,
            **inference_kwargs,
        )

    def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> Annotation:
        """Run the segmentation model on ``file`` and return its raw output.

        Parameters
        ----------
        file : AudioFile
            Processed file.
        hook : callable, optional
            Passed to ``self.setup_hook``; the resulting hook is not invoked
            by this method.

        Returns
        -------
        segmentations
            The value returned by ``self._segmentation(file)`` (the original
            comment describes it as shape (num_chunks, num_frames, 1)).
            No binarization happens here, despite the ``Annotation`` return
            annotation.
        """
        # Kept for parity with the parent pipeline; the hook itself is unused.
        hook = self.setup_hook(file, hook=hook)

        if not self.training:
            return self._segmentation(file)

        # Training mode: reuse the cached segmentation, or compute and cache it.
        cache_key = self.CACHED_SEGMENTATION
        if cache_key in file:
            return file[cache_key]
        fresh = self._segmentation(file)
        file[cache_key] = fresh
        return fresh
class Pyannote(Vad):
    """Voice activity detector backed by the pyannote segmentation pipeline."""

    def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
        """Load the VAD pipeline.

        ``kwargs['vad_onset']`` is required (a missing key raises KeyError) and
        is forwarded to ``Vad.__init__``.
        """
        print(">>Performing voice activity detection using Pyannote...")
        super().__init__(kwargs['vad_onset'])
        # NOTE(review): vad_onset/vad_offset are not forwarded to load_vad_model,
        # so the pipeline keeps that function's default hyperparameters; the
        # thresholds passed to merge_chunks are presumably the ones that matter.
        self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp)

    def __call__(self, audio: AudioFile, **kwargs):
        """Run the pipeline on ``audio``; extra keyword arguments are ignored."""
        return self.vad_pipeline(audio)

    @staticmethod
    def preprocess_audio(audio):
        """Convert a NumPy waveform to a torch tensor with a new leading axis."""
        return torch.from_numpy(audio).unsqueeze(0)

    @staticmethod
    def merge_chunks(segments,
                     chunk_size,
                     onset: float = 0.5,
                     offset: Optional[float] = None,
                     ):
        """Binarize raw scores into speech turns capped at ``chunk_size`` and merge them.

        Returns an empty list (after printing a notice) when no speech is
        found; otherwise delegates to ``Vad.merge_chunks``.
        """
        assert chunk_size > 0
        binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
        speech = binarize(segments)

        segments_list = [
            SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN")
            for speech_turn in speech.get_timeline()
        ]
        if not segments_list:
            print("No active speech found in audio")
            return []
        # (The former ``assert segments_list`` here was unreachable.)
        return Vad.merge_chunks(segments_list, chunk_size, onset, offset)
|