import os
import sys

import numpy as np
import torch
import torch.nn.functional as F

from functools import cached_property
from torch.nn.utils.rnn import pad_sequence

sys.path.append(os.getcwd())

from main.library.speaker_diarization.speechbrain import EncoderClassifier


class BaseInference:
    pass


class SpeechBrainPretrainedSpeakerEmbedding(BaseInference):
    """Speaker embedding extractor wrapping a pretrained SpeechBrain `EncoderClassifier`."""

    def __init__(self, embedding="assets/models/speaker_diarization/models/speechbrain", device=None):
        super().__init__()

        self.embedding = embedding
        self.device = device or torch.device("cpu")
        # Load the pretrained SpeechBrain encoder onto the requested device.
        self.classifier_ = EncoderClassifier.from_hparams(source=self.embedding, run_opts={"device": self.device})

    def to(self, device):
        if not isinstance(device, torch.device):
            raise TypeError(f"`device` must be a `torch.device`, got `{type(device).__name__}`")

        # Reload the pretrained classifier so that it runs on the target device.
        self.classifier_ = EncoderClassifier.from_hparams(source=self.embedding, run_opts={"device": device})
        self.device = device
        return self

    @cached_property
    def sample_rate(self):
        return self.classifier_.audio_normalizer.sample_rate

    @cached_property
    def dimension(self):
        # Infer the embedding dimension from the shape produced by encoding a random waveform.
        *_, dimension = self.classifier_.encode_batch(torch.rand(1, 16000).to(self.device)).shape
        return dimension

    @cached_property
    def metric(self):
        return "cosine"

    @cached_property
    def min_num_samples(self):
        # Binary-search the shortest input (in samples) that the encoder accepts without
        # raising, between 2 samples and half a second of audio.
        with torch.inference_mode():
            lower, upper = 2, round(0.5 * self.sample_rate)
            middle = (lower + upper) // 2

            while lower + 1 < upper:
                try:
                    _ = self.classifier_.encode_batch(torch.randn(1, middle).to(self.device))
                    upper = middle
                except RuntimeError:
                    lower = middle

                middle = (lower + upper) // 2

            return upper

    def __call__(self, waveforms, masks=None):
        batch_size, num_channels, num_samples = waveforms.shape
        assert num_channels == 1

        # (batch_size, 1, num_samples) -> (batch_size, num_samples)
        waveforms = waveforms.squeeze(dim=1)

        if masks is None:
            signals = waveforms
            wav_lens = signals.shape[1] * torch.ones(batch_size)
        else:
            batch_size_masks, _ = masks.shape
            assert batch_size == batch_size_masks

            # Upsample frame-level masks to sample resolution and binarize them.
            imasks = F.interpolate(masks.unsqueeze(dim=1), size=num_samples, mode="nearest").squeeze(dim=1) > 0.5
            # Keep only the unmasked samples of each waveform, then zero-pad to a common length.
            signals = pad_sequence([waveform[imask].contiguous() for waveform, imask in zip(waveforms, imasks)], batch_first=True)
            wav_lens = imasks.sum(dim=1)

        max_len = wav_lens.max()

        # If every signal is too short to be embedded, return NaN embeddings for the whole batch.
        if max_len < self.min_num_samples:
            return np.nan * np.zeros((batch_size, self.dimension))

        # Flag signals that are too short, and express lengths relative to the longest signal,
        # as expected by `encode_batch`.
        too_short = wav_lens < self.min_num_samples
        wav_lens = wav_lens / max_len
        wav_lens[too_short] = 1.0

        embeddings = self.classifier_.encode_batch(signals, wav_lens=wav_lens).squeeze(dim=1).cpu().numpy()
        embeddings[too_short.cpu().numpy()] = np.nan

        return embeddings
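

# Minimal usage sketch, assuming the pretrained SpeechBrain model files are present at the
# default `embedding` path; the input is random noise, purely for illustration.
if __name__ == "__main__":
    embedder = SpeechBrainPretrainedSpeakerEmbedding(device=torch.device("cpu"))

    # Two mono chunks of three seconds each: (batch_size, num_channels, num_samples).
    waveforms = torch.randn(2, 1, 3 * embedder.sample_rate)

    embeddings = embedder(waveforms)  # numpy array of shape (2, embedder.dimension)
    print(embeddings.shape)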