quentinbch committed
Commit 31caba0
1 Parent(s): f177d8d

add main.py file

Files changed (2)
  1. .gitignore +4 -0
  2. main.py +142 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+ .env
+ .venv
+ __pycache__/
+ .idea
main.py ADDED
@@ -0,0 +1,142 @@
+ from transformers import pipeline
+ import torch
+ from transformers.pipelines.audio_utils import ffmpeg_microphone_live
+ from huggingface_hub import HfFolder, InferenceClient
+ import requests
+ from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
+ from datasets import load_dataset
+ import sounddevice as sd
+ import sys
+ import os
+ from dotenv import load_dotenv
+ import gradio as gr
+ import warnings
+
+ load_dotenv()
+ HF_TOKEN = os.getenv("HF_TOKEN")
+
+ warnings.filterwarnings("ignore",
+                         message="At least one mel filter has all zero values.*",
+                         category=UserWarning)
+
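+ # Wake-word detection: audio classification pipeline on GPU if available, CPU otherwise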
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ classifier = pipeline(
+     "audio-classification",
+     model="MIT/ast-finetuned-speech-commands-v2",
+     device=device
+ )
+
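+ # Listen to the microphone stream and return True once the wake word is heard above the probability threshold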
+ def launch_fn(wake_word="marvin", prob_threshold=0.5, chunk_length_s=2.0, stream_chunk_s=0.25, debug=False):
+     if wake_word not in classifier.model.config.label2id.keys():
+         raise ValueError(
+             f"Wake word {wake_word} not in set of valid class labels, pick a wake word in the set {classifier.model.config.label2id.keys()}."
+         )
+
+     sampling_rate = classifier.feature_extractor.sampling_rate
+
+     mic = ffmpeg_microphone_live(
+         sampling_rate=sampling_rate,
+         chunk_length_s=chunk_length_s,
+         stream_chunk_s=stream_chunk_s,
+     )
+
+     print("Listening for wake word...")
+     for prediction in classifier(mic):
+         prediction = prediction[0]
+         if debug:
+             print(prediction)
+         if prediction["label"] == wake_word:
+             if prediction["score"] > prob_threshold:
+                 return True
+
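+ # Speech-to-text: stream microphone audio through Whisper and return the final transcription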
+ transcriber = pipeline(
+     "automatic-speech-recognition", model="openai/whisper-base.en", device=device
+ )
+
+ def transcribe(chunk_length_s=5.0, stream_chunk_s=1.0):
+     sampling_rate = transcriber.feature_extractor.sampling_rate
+
+     mic = ffmpeg_microphone_live(
+         sampling_rate=sampling_rate,
+         chunk_length_s=chunk_length_s,
+         stream_chunk_s=stream_chunk_s,
+     )
+
+     print("Start speaking...")
+     for item in transcriber(mic, generate_kwargs={"max_new_tokens": 128}):
+         sys.stdout.write("\033[K")
+         print(item["text"], end="\r")
+         if not item["partial"][0]:
+             break
+
+     return item["text"]
+
+
+
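+ # Text generation: send the transcription to an LLM through the Hugging Face InferenceClient (Fireworks provider)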
+ client = InferenceClient(
+     provider="fireworks-ai",
+     api_key=HF_TOKEN
+ )
+
+ def query(text, model_id="meta-llama/Llama-3.1-8B-Instruct"):
+     try:
+         completion = client.chat.completions.create(
+             model=model_id,
+             messages=[{"role": "user", "content": text}]
+         )
+         return completion.choices[0].message.content
+
+     except Exception as e:
+         print(f"Erreur: {str(e)}")
+         return None
+
+
+
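+ # Text-to-speech: SpeechT5 with a HiFi-GAN vocoder and a fixed speaker embedding from the CMU ARCTIC x-vectors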
+ processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
+ model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
+ vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
+
+ embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
+ speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
+
+
+ def synthesise(text):
+     input_ids = processor(text=text, return_tensors="pt")["input_ids"]
+     speech = model.generate_speech(
+         input_ids.to(device),
+         speaker_embeddings.to(device),
+         vocoder=vocoder
+     )
+     return speech.cpu()
+
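+ # One-shot command-line run, left commented out in favour of the Gradio interface below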
+ # launch_fn(debug=True)
+ # transcription = transcribe()
+ # response = query(transcription)
+ # audio = synthesise(response)
+ #
+ # sd.play(audio.numpy(), 16000)
+ # sd.wait()
+
+ # Gradio interface
+ def assistant_vocal_interface():
+     launch_fn(debug=True)
+     transcription = transcribe()
+     response = query(transcription)
+     audio = synthesise(response)
+     return transcription, response, (16000, audio.numpy())
+
+ with gr.Blocks(title="Assistant Vocal") as demo:
+     gr.Markdown("## Assistant vocal : détection, transcription, génération et synthèse")
+
+     start_btn = gr.Button("Démarrer l'assistant")
+     transcription_box = gr.Textbox(label="Transcription")
+     response_box = gr.Textbox(label="Réponse IA")
+     audio_output = gr.Audio(label="Synthèse vocale", type="numpy", autoplay=True)
+
+     start_btn.click(
+         assistant_vocal_interface,
+         inputs=[],
+         outputs=[transcription_box, response_box, audio_output]
+     )
+
+ demo.launch(share=True)