Commit
·
8e59fb5
1
Parent(s):
d1bb614
Update app.py
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ import torch
|
|
4 |
from datasets import load_dataset
|
5 |
|
6 |
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor, pipeline
|
7 |
-
|
8 |
|
9 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
10 |
|
@@ -25,6 +25,23 @@ vocoder = SpeechT5HifiGan.from_pretrained("sanchit-gandhi/speecht5_tts_vox_nl").
|
|
25 |
# embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
|
26 |
# speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
dataset_nl = load_dataset("facebook/voxpopuli", "nl", split="train", streaming=True)
|
29 |
data_list = []
|
30 |
speaker_embeddings_list = []
|
|
|
4 |
from datasets import load_dataset
|
5 |
|
6 |
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor, pipeline
|
7 |
+
from speechbrain.pretrained import EncoderClassifier
|
8 |
|
9 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
10 |
|
|
|
25 |
# embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
|
26 |
# speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
|
27 |
|
28 |
+
# Identifier of the pretrained SpeechBrain speaker-recognition checkpoint.
spk_model_name = "speechbrain/spkrec-xvect-voxceleb"

# Run the speaker encoder on the GPU when torch can see one, else on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the encoder once at import time; its files are stored under
# /tmp/<spk_model_name>.
speaker_model = EncoderClassifier.from_hparams(
    source=spk_model_name,
    savedir=os.path.join("/tmp", spk_model_name),
    run_opts={"device": device},
)
|
36 |
+
|
37 |
+
def create_speaker_embedding(waveform):
    """Compute a normalised speaker embedding for ``waveform``.

    The waveform is converted to a tensor, encoded with the module-level
    ``speaker_model``, normalised along dimension 2, stripped of all
    singleton dimensions and returned as a NumPy array on the CPU.

    Args:
        waveform: Audio samples accepted by ``torch.tensor``.
            NOTE(review): presumably a 1-D array of samples at the sample
            rate the speaker model expects -- confirm against the caller.

    Returns:
        numpy.ndarray: The normalised embedding with singleton dimensions
        removed.
    """
    audio = torch.tensor(waveform)

    # Inference only: no autograd graph is needed for the encoder pass.
    with torch.no_grad():
        raw_embedding = speaker_model.encode_batch(audio)
        unit_embedding = torch.nn.functional.normalize(raw_embedding, dim=2)

    # Drop size-1 dimensions and hand back a host-side NumPy array.
    return unit_embedding.squeeze().cpu().numpy()
|
43 |
+
|
44 |
+
|
45 |
dataset_nl = load_dataset("facebook/voxpopuli", "nl", split="train", streaming=True)
|
46 |
data_list = []
|
47 |
speaker_embeddings_list = []
|