# Voice_Assistant/app.py
from transformers import pipeline
import torch
from transformers.pipelines.audio_utils import ffmpeg_microphone_live
from huggingface_hub import InferenceClient
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset
import sounddevice as sd
import sys
import os
from dotenv import load_dotenv
import gradio as gr
import warnings
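# Read the Hugging Face API token from a local .env file (HF_TOKEN=...)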
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
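# Suppress a known-harmless warning raised by the transformers mel-filterbank computation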
warnings.filterwarnings("ignore",
message="At least one mel filter has all zero values.*",
category=UserWarning)
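# Run every model on GPU when one is available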
device = "cuda:0" if torch.cuda.is_available() else "cpu"
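# Audio classifier fine-tuned on Speech Commands, used for wake-word spotting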
classifier = pipeline(
"audio-classification",
model="MIT/ast-finetuned-speech-commands-v2",
device=device
)
def launch_fn(wake_word="marvin", prob_threshold=0.5, chunk_length_s=2.0, stream_chunk_s=0.25, debug=False):
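    """Listen to the live microphone stream and return True once the wake word is detected."""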
if wake_word not in classifier.model.config.label2id.keys():
raise ValueError(
f"Wake word {wake_word} not in set of valid class labels, pick a wake word in the set {classifier.model.config.label2id.keys()}."
)
sampling_rate = classifier.feature_extractor.sampling_rate
mic = ffmpeg_microphone_live(
sampling_rate=sampling_rate,
chunk_length_s=chunk_length_s,
stream_chunk_s=stream_chunk_s,
)
print("Listening for wake word...")
for prediction in classifier(mic):
prediction = prediction[0]
if debug:
print(prediction)
if prediction["label"] == wake_word:
if prediction["score"] > prob_threshold:
return True
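# Whisper (English-only) model for speech-to-text transcription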
transcriber = pipeline(
"automatic-speech-recognition", model="openai/whisper-base.en", device=device
)
def transcribe(chunk_length_s=5.0, stream_chunk_s=1.0):
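    """Stream the microphone through Whisper and return the final transcript."""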
sampling_rate = transcriber.feature_extractor.sampling_rate
mic = ffmpeg_microphone_live(
sampling_rate=sampling_rate,
chunk_length_s=chunk_length_s,
stream_chunk_s=stream_chunk_s,
)
print("Start speaking...")
for item in transcriber(mic, generate_kwargs={"max_new_tokens": 128}):
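        # Clear the current terminal line before printing the updated partial transcript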
sys.stdout.write("\033[K")
print(item["text"], end="\r")
if not item["partial"][0]:
break
return item["text"]
client = InferenceClient(
provider="fireworks-ai",
api_key=HF_TOKEN
)
def query(text, model_id="meta-llama/Llama-3.1-8B-Instruct"):
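    """Send the transcribed text to the LLM and return its reply, or None on error."""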
try:
completion = client.chat.completions.create(
model=model_id,
messages=[{"role": "user", "content": text}]
)
return completion.choices[0].message.content
except Exception as e:
print(f"Erreur: {str(e)}")
return None
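# SpeechT5 text-to-speech: processor, acoustic model, HiFi-GAN vocoder, and a speaker embedding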
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
def synthesise(text):
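    """Convert text to a 16 kHz speech waveform with SpeechT5; returns None on failure."""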
    try:
        input_ids = processor(text=text, return_tensors="pt")["input_ids"]
speech = model.generate_speech(
input_ids.to(device),
speaker_embeddings.to(device),
vocoder=vocoder
)
return speech.cpu()
except Exception as e:
print(f"Erreur lors de la synthèse vocale : {e}")
return None
# CLI usage, without the Gradio interface:
# launch_fn(debug=True)
# transcription = transcribe()
# response = query(transcription)
# audio = synthesise(response)
# sd.play(audio.numpy(), 16000)
# sd.wait()
# Gradio interface
def assistant_vocal_interface():
    launch_fn(debug=True)
    transcription = transcribe()
    response = query(transcription)
    if response is None:
        # The LLM call failed; return the transcription without audio
        return transcription, "No response from the model.", None
    audio = synthesise(response)
    if audio is None:
        # Synthesis failed; still return the text results
        return transcription, response, None
    return transcription, response, (16000, audio.numpy())
with gr.Blocks(title="Voice Assistant") as demo:
    gr.Markdown("## Voice assistant: wake-word detection, transcription, generation, and speech synthesis")
    start_btn = gr.Button("Start the assistant")
    transcription_box = gr.Textbox(label="Transcription")
    response_box = gr.Textbox(label="AI response")
    audio_output = gr.Audio(label="Speech output", type="numpy", autoplay=True)
start_btn.click(
assistant_vocal_interface,
inputs=[],
outputs=[transcription_box, response_box, audio_output]
)
if __name__ == "__main__":
    demo.launch(share=True)