File size: 3,836 Bytes
4fb650d
 
 
6218f6a
 
a70a34d
 
36420ca
4fb650d
 
 
 
6218f6a
 
 
 
4fb650d
36420ca
 
 
 
6218f6a
be00791
 
 
 
 
 
 
6218f6a
 
be00791
6218f6a
 
 
4fb650d
6218f6a
 
be00791
4fb650d
36420ca
 
 
 
6218f6a
4fb650d
6218f6a
4fb650d
 
 
 
 
 
 
 
 
 
 
 
 
6218f6a
 
 
4fb650d
 
a70a34d
6218f6a
4fb650d
6218f6a
 
4fb650d
 
6218f6a
 
 
 
 
 
 
 
 
 
4fb650d
6218f6a
 
4fb650d
6218f6a
4fb650d
 
6218f6a
 
4fb650d
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
import numpy as np
import torch
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5ForSpeechToText
import soundfile as sf
import tempfile
import os
from datasets import load_dataset

# Check if CUDA is available, otherwise use CPU.
# Every model and tensor created below is moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load SpeechT5 models and processor.
# A single processor, taken from the ASR checkpoint, is shared by both the
# speech-to-text and the text-to-speech functions in this file.
# NOTE(review): presumably the ASR and TTS checkpoints use the same tokenizer,
# so sharing the processor is safe - confirm against the model cards.
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_asr")
asr_model = SpeechT5ForSpeechToText.from_pretrained("microsoft/speecht5_asr").to(device)
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)

# Load speaker embeddings.
# One x-vector row is turned into a tensor, given a leading batch dimension
# with unsqueeze(0), and moved to `device`; text_to_speech passes it to the
# TTS model as the speaker embedding.
# NOTE(review): the reason for choosing row 7306 is not stated here -
# presumably it is just a preferred voice; confirm.
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device)

# Function to convert speech to text
def _resample_audio(audio_array, orig_rate, target_rate=16000):
    """Resample a 1-D waveform from ``orig_rate`` to ``target_rate`` Hz.

    Uses plain linear interpolation (no anti-aliasing filter), which keeps the
    file free of extra dependencies and is adequate for speech recognition
    input.

    Args:
        audio_array: 1-D array-like of samples.
        orig_rate: Sampling rate of ``audio_array`` in Hz.
        target_rate: Desired sampling rate in Hz.

    Returns:
        A float32 NumPy array sampled at ``target_rate``.
    """
    audio_array = np.asarray(audio_array, dtype=np.float32)
    if orig_rate == target_rate or audio_array.size == 0:
        return audio_array

    target_len = max(1, int(round(audio_array.size * target_rate / orig_rate)))
    source_times = np.arange(audio_array.size, dtype=np.float64) / orig_rate
    target_times = np.arange(target_len, dtype=np.float64) / target_rate
    return np.interp(target_times, source_times, audio_array).astype(np.float32)


def speech_to_text(audio_dict):
    """Transcribe a waveform with the SpeechT5 ASR model.

    Args:
        audio_dict: Mapping with an ``"array"`` entry holding the mono waveform
            and an optional ``"sampling_rate"`` entry in Hz. When the rate is
            missing, 16 kHz is assumed.

    Returns:
        The transcription as a string; an empty string for empty input.
    """
    target_rate = 16000  # rate the processor is told the audio is sampled at
    audio_array = np.asarray(audio_dict["array"], dtype=np.float32)
    if audio_array.size == 0:
        return ""

    # Honour the caller's sampling rate instead of silently treating every
    # recording as 16 kHz (browser microphones usually record at 44.1/48 kHz).
    source_rate = audio_dict.get("sampling_rate") or target_rate
    audio_array = _resample_audio(audio_array, source_rate, target_rate)

    # Keep the whole processor output (input values and attention mask).
    inputs = processor(
        audio=audio_array, sampling_rate=target_rate, return_tensors="pt"
    ).to(device)

    # SpeechT5 ASR is an encoder-decoder model: the text has to be produced
    # with autoregressive generation, not by an argmax over a single forward
    # pass (which has no decoder inputs to work with).
    with torch.no_grad():
        predicted_ids = asr_model.generate(**inputs, max_length=200)

    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

# Function to convert text to speech
_vocoder = None  # HiFi-GAN vocoder, loaded on first use by _get_vocoder()


def _get_vocoder():
    """Return the SpeechT5 HiFi-GAN vocoder, loading and caching it on first use."""
    global _vocoder
    if _vocoder is None:
        # Imported here so this block does not depend on the file-level imports.
        from transformers import SpeechT5HifiGan

        _vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
    return _vocoder


def text_to_speech(text):
    """Synthesize speech for ``text`` with the SpeechT5 TTS model.

    Args:
        text: The text to speak.

    Returns:
        A torch tensor holding the generated waveform, located on ``device``.
    """
    inputs = processor(text=text, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        # Without a vocoder, generate_speech returns a mel spectrogram rather
        # than audio samples; the vocoder turns it into a playable waveform.
        speech = tts_model.generate_speech(
            inputs,
            speaker_embeddings=speaker_embeddings,
            vocoder=_get_vocoder(),
        )
    return speech

# Gradio demo
def _to_mono_float32(samples):
    """Convert raw recorded samples to a mono float32 waveform in [-1.0, 1.0].

    Integer PCM data is scaled by the full range of its dtype; floating-point
    data is only cast. Multi-channel input is averaged across channels.
    """
    samples = np.asarray(samples)
    if np.issubdtype(samples.dtype, np.integer):
        info = np.iinfo(samples.dtype)
        full_scale = (float(info.max) - float(info.min) + 1.0) / 2.0  # 32768 for int16
        midpoint = float(info.min) + full_scale  # 0 for signed types
        samples = (samples.astype(np.float64) - midpoint) / full_scale
    samples = samples.astype(np.float32)

    if samples.ndim > 1:
        # assumes the layout is (samples, channels) - see the Gradio Audio docs.
        # Flattening would interleave the channels and distort the signal.
        samples = samples.mean(axis=1)
    return samples


def _to_int16_pcm(waveform):
    """Clip a float waveform to [-1.0, 1.0] and convert it to 16-bit PCM."""
    clipped = np.clip(np.asarray(waveform, dtype=np.float32), -1.0, 1.0)
    return (clipped * 32767.0).astype(np.int16)


def demo():
    """Build the voice chatbot Gradio interface and launch it.

    The interface records microphone audio, transcribes it, echoes the
    transcript back as synthesized speech, and shows the exchange as text.
    """
    with gr.Blocks() as app:
        gr.Markdown("# Voice Chatbot")
        gr.Markdown("Simply speak into the microphone and get an audio response.")

        audio_input = gr.Audio(sources=["microphone"], type="numpy", label="Speak")
        audio_output = gr.Audio(label="Response", autoplay=True)
        transcript_display = gr.Textbox(label="Conversation")

        def process_audio(audio):
            """Handle one recording.

            Args:
                audio: ``(sample_rate, samples)`` tuple from the Audio
                    component, or ``None`` when the input was cleared.

            Returns:
                ``(response_audio, conversation_text)`` where
                ``response_audio`` is ``(sample_rate, int16 samples)`` or
                ``None`` when there is nothing to play.
            """
            if audio is None:
                return None, "No audio detected."

            sample_rate, audio_data = audio
            audio_data = _to_mono_float32(audio_data)
            if audio_data.size == 0:
                return None, "No audio detected."

            # Speech-to-text
            transcript = speech_to_text({"array": audio_data, "sampling_rate": sample_rate})
            print(f"Transcribed: {transcript}")
            if not transcript.strip():
                # Nothing was recognized, so there is nothing to synthesize.
                return None, "No speech recognized."

            # Generate response (for simplicity, echo the transcript)
            response_text = transcript
            print(f"Response: {response_text}")

            # Text-to-speech
            response_audio = text_to_speech(response_text)

            # Hand the samples to Gradio directly as 16-bit PCM; a temporary
            # WAV round trip adds nothing and leaks the file on errors.
            waveform = response_audio.detach().cpu().numpy()
            return (16000, _to_int16_pcm(waveform)), f"You: {transcript}\nAssistant: {response_text}"

        audio_input.change(
            process_audio,
            inputs=[audio_input],
            outputs=[audio_output, transcript_display],
        )

        clear_btn = gr.Button("Clear Conversation")
        clear_btn.click(lambda: (None, ""), outputs=[audio_output, transcript_display])

    app.launch()

# Start the demo only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo()