import gradio as gr
import numpy as np
import torch
from transformers import pipeline, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset
import soundfile as sf
import tempfile
import os
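# Assumed extra dependencies (requirements.txt sketch): sentencepiece is needed
# by the SpeechT5 tokenizer, and torchaudio is used by the ASR pipeline to
# resample raw audio that is not already 16 kHz.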
# Check if CUDA is available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load Whisper for ASR
print("Loading ASR model...")
asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-small", device=device)
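# Whisper is trained on 16 kHz audio; when given a {"raw", "sampling_rate"} dict,
# the pipeline resamples to 16 kHz internally if the rates differ.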
# Load SpeechT5 for TTS
print("Loading TTS model...")
tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
# Load the HiFi-GAN vocoder that turns SpeechT5 spectrograms into waveforms
print("Loading vocoder...")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
# Load speaker embeddings for TTS
print("Loading speaker embeddings...")
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device)
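# Index 7306 is the x-vector used in the SpeechT5 examples (a US English female
# voice from CMU ARCTIC); any other row of the dataset also works.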
# Function to convert speech to text using Whisper
def speech_to_text(audio_data, sample_rate):
    # Downmix stereo to mono (flattening would interleave the channels)
    if audio_data.ndim > 1:
        audio_data = audio_data.mean(axis=1)
    # Normalize int16 PCM from Gradio to float32 in [-1, 1]
    if audio_data.dtype == np.int16:
        audio_data = audio_data.astype(np.float32) / 32768.0
    else:
        audio_data = audio_data.astype(np.float32)
    # Process with Whisper
    result = asr_pipeline({"raw": audio_data, "sampling_rate": sample_rate})
    return result["text"]
# Function to convert text to speech using SpeechT5
def text_to_speech(text):
    # Process text input
    inputs = tts_processor(text=text, return_tensors="pt").to(device)
    # Generate speech with speaker embeddings
    with torch.no_grad():
        speech = tts_model.generate_speech(
            inputs["input_ids"],
            speaker_embeddings=speaker_embeddings,
        )
        # Convert spectrogram to waveform using vocoder
        waveform = vocoder(speech)
    return waveform
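# Note: generate_speech can also apply the vocoder itself, e.g.
#   waveform = tts_model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
# which returns the waveform directly instead of a spectrogram.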
# Gradio demo
def demo():
    with gr.Blocks() as demo:
        gr.Markdown("# Voice Chatbot")
        gr.Markdown("Simply speak into the microphone and get an audio response.")
        audio_input = gr.Audio(sources=["microphone"], type="numpy", label="Speak")
        audio_output = gr.Audio(label="Response", autoplay=True)
        transcript_display = gr.Textbox(label="Conversation")
        def process_audio(audio):
            if audio is None:
                return None, "No audio detected."
            try:
                # Gradio delivers a (sample_rate, numpy array) tuple
                sample_rate, audio_data = audio
                # Speech-to-text
                transcript = speech_to_text(audio_data, sample_rate)
                print(f"Transcribed: {transcript}")
                # Generate response (for simplicity, echo the transcript)
                response_text = transcript
                print(f"Response: {response_text}")
                # Text-to-speech
                response_audio = text_to_speech(response_text)
                # Save the response audio to a temporary file
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                    audio_np = response_audio.cpu().numpy()
                    # Normalize audio to avoid clipping
                    audio_np = audio_np / (np.max(np.abs(audio_np)) + 1e-8) * 0.9
                    # SpeechT5 generates 16 kHz audio
                    sf.write(temp_file.name, audio_np, 16000)
                    temp_filename = temp_file.name
                # Read the audio file back, then clean up the temporary file
                audio_data, sample_rate = sf.read(temp_filename)
                os.unlink(temp_filename)
                return (sample_rate, audio_data), f"You: {transcript}\nAssistant: {response_text}"
            except Exception as e:
                print(f"Error in process_audio: {e}")
                import traceback
                traceback.print_exc()
                return None, f"Error processing audio: {str(e)}"
        audio_input.change(
            process_audio,
            inputs=[audio_input],
            outputs=[audio_output, transcript_display],
        )
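        # Note: on recent Gradio versions, audio_input.stop_recording(...) can be
        # used instead of .change() so the handler fires only when recording ends.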
        clear_btn = gr.Button("Clear Conversation")
        clear_btn.click(lambda: (None, ""), outputs=[audio_output, transcript_display])
    demo.launch()
if __name__ == "__main__":
    demo()