import gradio as gr
import numpy as np
import torch
from transformers import pipeline, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset
import soundfile as sf
import tempfile
import os
# Check if CUDA is available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load Whisper for ASR
print("Loading ASR model...")
asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-small", device=device)
# Load SpeechT5 for TTS
print("Loading TTS model...")
tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
# Load the HiFi-GAN vocoder that turns SpeechT5 spectrograms into waveforms
print("Loading vocoder...")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
# Load speaker embeddings for TTS
print("Loading speaker embeddings...")
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
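# Index 7306 is the speaker x-vector used in the Hugging Face SpeechT5 examples (a US English female voice from CMU ARCTIC)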
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device)
# Function to convert speech to text using Whisper
def speech_to_text(audio_data, sample_rate):
    # Convert int16 microphone samples to float32 in the [-1, 1] range
    audio_data = audio_data.flatten().astype(np.float32) / 32768.0
    # Run Whisper on the raw waveform
    result = asr_pipeline({"raw": audio_data, "sampling_rate": sample_rate})
    return result["text"]

# Function to convert text to speech using SpeechT5
def text_to_speech(text):
    # Tokenize the text input
    inputs = tts_processor(text=text, return_tensors="pt").to(device)
    # Generate a mel spectrogram conditioned on the speaker embedding
    with torch.no_grad():
        speech = tts_model.generate_speech(
            inputs["input_ids"],
            speaker_embeddings=speaker_embeddings
        )
        # Convert the spectrogram to a waveform using the HiFi-GAN vocoder
        waveform = vocoder(speech)
    return waveform
# Gradio demo
def demo():
    with gr.Blocks() as demo:
        gr.Markdown("# Voice Chatbot")
        gr.Markdown("Simply speak into the microphone and get an audio response.")
        audio_input = gr.Audio(sources=["microphone"], type="numpy", label="Speak")
        audio_output = gr.Audio(label="Response", autoplay=True)
        transcript_display = gr.Textbox(label="Conversation")

        def process_audio(audio):
            if audio is None:
                return None, "No audio detected."
            try:
                # Gradio delivers microphone input as a (sample_rate, numpy array) tuple
                sample_rate, audio_data = audio
                # Speech-to-text
                transcript = speech_to_text(audio_data, sample_rate)
                print(f"Transcribed: {transcript}")
                # Generate response (for simplicity, echo the transcript)
                response_text = transcript
                print(f"Response: {response_text}")
                # Text-to-speech
                response_audio = text_to_speech(response_text)
                # Save the response audio to a temporary WAV file
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                    # Move the waveform to CPU as a numpy array
                    audio_np = response_audio.cpu().numpy()
                    # Normalize audio to avoid clipping
                    audio_np = audio_np / (np.max(np.abs(audio_np)) + 1e-8) * 0.9
                    sf.write(temp_file.name, audio_np, 16000)
                    temp_filename = temp_file.name
                # Read the rendered audio back
                audio_data, sample_rate = sf.read(temp_filename)
                # Clean up the temporary file
                os.unlink(temp_filename)
                return (sample_rate, audio_data), f"You: {transcript}\nAssistant: {response_text}"
            except Exception as e:
                print(f"Error in process_audio: {e}")
                import traceback
                traceback.print_exc()
                return None, f"Error processing audio: {str(e)}"

        audio_input.change(
            process_audio,
            inputs=[audio_input],
            outputs=[audio_output, transcript_display]
        )

        clear_btn = gr.Button("Clear Conversation")
        clear_btn.click(lambda: (None, ""), outputs=[audio_output, transcript_display])

    demo.launch()
if __name__ == "__main__":
    demo()