hashhac committed on
Commit
36420ca
·
1 Parent(s): be00791

embeddings added

Browse files
Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -5,6 +5,7 @@ from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5For
5
  import soundfile as sf
6
  import tempfile
7
  import os
 
8
 
9
  # Check if CUDA is available, otherwise use CPU
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -14,6 +15,10 @@ processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_asr")
14
  asr_model = SpeechT5ForSpeechToText.from_pretrained("microsoft/speecht5_asr").to(device)
15
  tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
16
 
 
 
 
 
17
  # Function to convert speech to text
18
  def speech_to_text(audio_dict):
19
  # Extract the audio array from the dictionary
@@ -33,7 +38,10 @@ def speech_to_text(audio_dict):
33
  def text_to_speech(text):
34
  inputs = processor(text=text, return_tensors="pt").input_ids.to(device)
35
  with torch.no_grad():
36
- speech = tts_model.generate_speech(inputs)
 
 
 
37
  return speech
38
 
39
  # Gradio demo
 
5
  import soundfile as sf
6
  import tempfile
7
  import os
8
+ from datasets import load_dataset
9
 
10
  # Check if CUDA is available, otherwise use CPU
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
15
  asr_model = SpeechT5ForSpeechToText.from_pretrained("microsoft/speecht5_asr").to(device)
16
  tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
17
 
18
+ # Load speaker embeddings
19
+ embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
20
+ speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device)
21
+
22
  # Function to convert speech to text
23
  def speech_to_text(audio_dict):
24
  # Extract the audio array from the dictionary
 
38
  def text_to_speech(text):
39
  inputs = processor(text=text, return_tensors="pt").input_ids.to(device)
40
  with torch.no_grad():
41
+ speech = tts_model.generate_speech(
42
+ inputs,
43
+ speaker_embeddings=speaker_embeddings
44
+ )
45
  return speech
46
 
47
  # Gradio demo