hashhac committed on
Commit ca1dafb · 1 Parent(s): ca032b0
Files changed (1)
  1. app.py +31 -26
app.py CHANGED
@@ -1,7 +1,8 @@
 import gradio as gr
 import numpy as np
 import torch
-from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5ForSpeechToText
+from transformers import pipeline, SpeechT5Processor, SpeechT5ForTextToSpeech
+from datasets import load_dataset
 import soundfile as sf
 import tempfile
 import os
@@ -9,36 +10,41 @@ import os
 # Check if CUDA is available, otherwise use CPU
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Load SpeechT5 models and processor
-processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_asr")
-asr_model = SpeechT5ForSpeechToText.from_pretrained("microsoft/speecht5_asr").to(device)
+# Load Whisper for ASR (much more reliable than SpeechT5 for ASR)
+print("Loading ASR model...")
+asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-small", device=device)
+
+# Load SpeechT5 for TTS
+print("Loading TTS model...")
+tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
 tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
 
-# Function to convert speech to text
-def speech_to_text(audio_dict):
-    # Extract the audio array from the dictionary
-    audio_array = audio_dict["array"]
-
-    # Pass the audio array directly to the processor
-    inputs = processor(audio=audio_array, sampling_rate=16000, return_tensors="pt").input_values.to(device)
-
-    with torch.no_grad():
-        logits = asr_model(inputs).logits
+# Load speaker embeddings for TTS
+print("Loading speaker embeddings...")
+embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
+speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device)
+
+# Function to convert speech to text using Whisper
+def speech_to_text(audio_data, sample_rate):
+    # Normalize audio data
+    audio_data = audio_data.flatten().astype(np.float32) / 32768.0
 
-    predicted_ids = torch.argmax(logits, dim=-1)
-    transcription = processor.batch_decode(predicted_ids)[0]
-    return transcription
+    # Process with Whisper
+    result = asr_pipeline({"raw": audio_data, "sampling_rate": sample_rate})
+    return result["text"]
 
-# Function to convert text to speech
+# Function to convert text to speech using SpeechT5
 def text_to_speech(text):
-    inputs = processor(text=text, return_tensors="pt").input_ids.to(device)
-    # Create dummy decoder input IDs (this is a simplification)
-    decoder_input_ids = torch.zeros((1, 1), dtype=torch.long).to(device)
+    # Process text input
+    inputs = tts_processor(text=text, return_tensors="pt").to(device)
+
+    # Generate speech with speaker embeddings
     with torch.no_grad():
         speech = tts_model.generate_speech(
-            inputs,
-            decoder_input_ids=decoder_input_ids
+            inputs["input_ids"],
+            speaker_embeddings=speaker_embeddings
         )
+
     return speech
 
 # Gradio demo
@@ -55,12 +61,11 @@ def demo():
         if audio is None:
             return None, "No audio detected."
 
-        # Convert audio to the correct format
+        # Get audio data
         sample_rate, audio_data = audio
-        audio_data = audio_data.flatten().astype(np.float32) / 32768.0  # Normalize to [-1.0, 1.0]
 
         # Speech-to-text
-        transcript = speech_to_text({"array": audio_data, "sampling_rate": sample_rate})
+        transcript = speech_to_text(audio_data, sample_rate)
         print(f"Transcribed: {transcript}")
 
         # Generate response (for simplicity, echo the transcript)
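Review note: the rewritten speech_to_text hands Whisper a dict of raw float32 samples plus their sampling rate, which is the input format the transformers ASR pipeline accepts. A minimal sketch of that path outside Gradio; the one-second silent clip is a made-up stand-in for gr.Audio's (sample_rate, int16 array) microphone output:

import numpy as np
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-small")

# Stand-in for Gradio's microphone output: (sample_rate, int16 array)
sample_rate = 16000
audio_int16 = np.zeros(sample_rate, dtype=np.int16)  # one second of silence

# Same normalization the commit moves into speech_to_text: int16 -> [-1.0, 1.0]
audio = audio_int16.flatten().astype(np.float32) / 32768.0
result = asr({"raw": audio, "sampling_rate": sample_rate})
print(result["text"])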
 
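Review note: SpeechT5ForTextToSpeech.generate_speech returns a mel spectrogram when no vocoder is passed, so the revised text_to_speech still does not produce a playable waveform on its own. A sketch of the same TTS path with the SpeechT5HifiGan vocoder added; the vocoder checkpoint, the sample sentence, and the output filename are illustrative, not part of this commit:

import torch
import soundfile as sf
from datasets import load_dataset
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# Same x-vector speaker embedding the commit loads
embeddings = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker = torch.tensor(embeddings[7306]["xvector"]).unsqueeze(0)

inputs = processor(text="Hello from SpeechT5.", return_tensors="pt")
with torch.no_grad():
    # With a vocoder attached, generate_speech returns a 16 kHz waveform tensor
    speech = model.generate_speech(
        inputs["input_ids"], speaker_embeddings=speaker, vocoder=vocoder
    )

sf.write("out.wav", speech.numpy(), samplerate=16000)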