hashhac committed on
Commit 7afe31a · 1 Parent(s): e724e7e
Files changed (2):
  1. app.py +216 -37
  2. requirements.txt +6 -15
app.py CHANGED
@@ -3,49 +3,228 @@ from fastrtc import (
      audio_to_bytes, aggregate_bytes_to_16bit
  )
  import gradio as gr
- from groq import Groq
  import numpy as np
- import anthropic
- from elevenlabs import ElevenLabs
-
- groq_client = Groq()
- claude_client = anthropic.Anthropic()
- tts_client = ElevenLabs()
-
-
- # See "Talk to Claude" in Cookbook for an example of how to keep
- # track of the chat history.
- def response(
-     audio: tuple[int, np.ndarray],
- ):
-     prompt = groq_client.audio.transcriptions.create(
-         file=("audio-file.mp3", audio_to_bytes(audio)),
-         model="whisper-large-v3-turbo",
-         response_format="verbose_json",
-     ).text
-     response = claude_client.messages.create(
-         model="claude-3-5-haiku-20241022",
-         max_tokens=512,
-         messages=[{"role": "user", "content": prompt}],
+ import torch
+ import os
+ from transformers import (
+     AutoModelForSpeechSeq2Seq,
+     AutoProcessor,
+     pipeline,
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     SpeechT5ForTextToSpeech, SpeechT5HifiGan
+ )
+ from datasets import load_dataset
+ import scipy
+
+ # Check if CUDA is available, otherwise use CPU
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+ # Step 1: Audio transcription with Whisper
+ def load_asr_model():
+     model_id = "openai/whisper-small"  # Smaller version that's more efficient
+
+     model = AutoModelForSpeechSeq2Seq.from_pretrained(
+         model_id,
+         torch_dtype=torch_dtype,
+         low_cpu_mem_usage=True,
+         use_safetensors=True
      )
-     response_text = " ".join(
-         block.text
-         for block in response.content
-         if getattr(block, "type", None) == "text"
+     model.to(device)
+
+     processor = AutoProcessor.from_pretrained(model_id)
+
+     return pipeline(
+         "automatic-speech-recognition",
+         model=model,
+         tokenizer=processor.tokenizer,
+         feature_extractor=processor.feature_extractor,
+         max_new_tokens=128,
+         chunk_length_s=30,
+         batch_size=16,
+         return_timestamps=False,
+         torch_dtype=torch_dtype,
+         device=device,
      )
-     iterator = tts_client.text_to_speech.convert_as_stream(
-         text=response_text,
-         voice_id="JBFqnCBsd6RMkjVDRZzb",
-         model_id="eleven_multilingual_v2",
-         output_format="pcm_24000"
-
+
+ # Step 2: Text generation with a smaller LLM
+ def load_llm_model():
+     model_id = "facebook/opt-1.3b"  # A smaller language model
+
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         torch_dtype=torch_dtype,
+         low_cpu_mem_usage=True
      )
-     for chunk in aggregate_bytes_to_16bit(iterator):
-         audio_array = np.frombuffer(chunk, dtype=np.int16).reshape(1, -1)
-         yield (24000, audio_array)
+     model.to(device)
+
+     return model, tokenizer
+
+ # Step 3: Text-to-Speech with a free model
+ def load_tts_model():
+     model_id = "microsoft/speecht5_tts"
+     processor = AutoProcessor.from_pretrained(model_id)
+     model = SpeechT5ForTextToSpeech.from_pretrained(model_id)
+     model.to(device)
+
+     # Load vocoder for waveform generation
+     vocoder_id = "microsoft/speecht5_hifigan"
+     vocoder = SpeechT5HifiGan.from_pretrained(vocoder_id)
+     vocoder.to(device)
+
+     # Load speaker embeddings
+     embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
+     speaker_embeddings = torch.tensor(embeddings_dataset[7]["xvector"]).unsqueeze(0)
+
+     return model, processor, vocoder, speaker_embeddings
+
+ # Initialize all models
+ print("Loading ASR model...")
+ asr_pipeline = load_asr_model()
+
+ print("Loading LLM model...")
+ llm_model, llm_tokenizer = load_llm_model()
+
+ print("Loading TTS model...")
+ tts_model, tts_processor, tts_vocoder, speaker_embeddings = load_tts_model()
+
+ # Chat history management
+ chat_history = []
+
+ def generate_response(prompt):
+     # If chat history is empty, add a system message
+     if not chat_history:
+         chat_history.append({"role": "system", "content": "You are a helpful, friendly AI assistant. Keep your responses concise and conversational."})
+
+     # Add user message to history
+     chat_history.append({"role": "user", "content": prompt})
+
+     # Prepare input for the model
+     full_prompt = ""
+     for message in chat_history:
+         if message["role"] == "system":
+             full_prompt += f"System: {message['content']}\n"
+         elif message["role"] == "user":
+             full_prompt += f"User: {message['content']}\n"
+         elif message["role"] == "assistant":
+             full_prompt += f"Assistant: {message['content']}\n"
+
+     full_prompt += "Assistant: "
+
+     # Generate response
+     inputs = llm_tokenizer(full_prompt, return_tensors="pt").to(device)
+     with torch.no_grad():
+         output = llm_model.generate(
+             **inputs,
+             max_new_tokens=128,
+             do_sample=True,
+             temperature=0.7,
+             top_p=0.9
+         )
+
+     response_text = llm_tokenizer.decode(output[0], skip_special_tokens=True)
+     response_text = response_text.split("Assistant: ")[-1].strip()
+
+     # Add assistant response to history
+     chat_history.append({"role": "assistant", "content": response_text})
+
+     # Keep history at a reasonable size
+     if len(chat_history) > 10:
+         # Keep system message and last 9 exchanges
+         chat_history.pop(1)
+
+     return response_text
+
+ def text_to_speech(text):
+     # Prepare inputs
+     inputs = tts_processor(text=text, return_tensors="pt")
+
+     # Add speaker embeddings
+     inputs["speaker_embeddings"] = speaker_embeddings.to(device)
+
+     # Generate speech
+     with torch.no_grad():
+         speech = tts_model.generate_speech(
+             inputs["input_ids"].to(device),
+             speaker_embeddings.to(device)
+         )
+
+     # Convert to waveform using vocoder
+     with torch.no_grad():
+         waveform = tts_vocoder(speech)
+
+     # Convert to numpy array
+     audio_array = waveform.cpu().numpy().squeeze()
+
+     # Normalize and convert to int16
+     audio_array = (audio_array / np.max(np.abs(audio_array)) * 32767).astype(np.int16)
+
+     # Reshape for fastrtc
+     audio_array = audio_array.reshape(1, -1)
+
+     return (16000, audio_array)  # SpeechT5's HiFi-GAN vocoder outputs 16 kHz audio
+
+ def response(audio: tuple[int, np.ndarray]):
+     # Step 1: Speech-to-Text
+     transcript = asr_pipeline({"sampling_rate": audio[0], "raw": audio[1].flatten()})
+     prompt = transcript["text"]
+
+     # Step 2: Generate text response
+     response_text = generate_response(prompt)
+
+     # Step 3: Text-to-Speech
+     sample_rate, audio_array = text_to_speech(response_text)
+
+     # Convert to expected format
+     chunk_size = 3200  # 200ms chunks at 16kHz
+     for i in range(0, audio_array.shape[1], chunk_size):
+         chunk = audio_array[:, i:i+chunk_size]
+         if chunk.size > 0:  # Ensure we don't yield empty chunks
+             yield (sample_rate, chunk)

  stream = Stream(
      modality="audio",
      mode="send-receive",
      handler=ReplyOnPause(response),
- )
+ )
+
+ # For testing without WebRTC
+ def demo():
+     with gr.Blocks() as demo:
+         gr.Markdown("# Local Voice Chatbot")
+         audio_input = gr.Audio(sources=["microphone"], type="numpy")
+         audio_output = gr.Audio()
+
+         def process_audio(audio):
+             if audio is None:
+                 return None
+
+             sample_rate, audio_array = audio
+             transcript = asr_pipeline({"sampling_rate": sample_rate, "raw": audio_array.flatten()})
+             prompt = transcript["text"]
+             print(f"Transcribed: {prompt}")
+
+             response_text = generate_response(prompt)
+             print(f"Response: {response_text}")
+
+             sample_rate, audio_array = text_to_speech(response_text)
+             return (sample_rate, audio_array[0])
+
+         audio_input.change(process_audio, inputs=[audio_input], outputs=[audio_output])
+
+     demo.launch()
+
+ if __name__ == "__main__":
+     import argparse
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--demo", action="store_true", help="Run Gradio demo instead of WebRTC")
+     args = parser.parse_args()
+
+     if args.demo:
+         demo()
+     else:
+         # For running with FastRTC
+         # You would need to add your FastRTC server code here
+         pass
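
Note on the stub above: the else branch in app.py is committed as a bare pass. A minimal way to exercise the WebRTC path locally is sketched below; it is not part of this commit and assumes fastrtc's Stream exposes a ready-made Gradio UI as stream.ui (as shown in the fastrtc quickstart), so verify against the installed fastrtc version.

    if args.demo:
        demo()
    else:
        # Sketch only (assumption): launch the Gradio UI that fastrtc builds
        # around the Stream instead of leaving the branch as `pass`.
        stream.ui.launch()
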
requirements.txt CHANGED
@@ -1,16 +1,7 @@
- fastapi
- uvicorn
- transformers
- torch
- numpy
- # librosa
- python-dotenv
- fastrtc[vad, tts]
- # SentencePiece
- # twilio
+ transformers
+ torch
+ datasets
+ scipy
+ fastrtc
  gradio
- # torchaudio
- elevenlabs
- groq
- anthropic
- ffmpeg
+ accelerate
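
With the trimmed requirements, a typical local setup is (commands illustrative, not part of the commit):

    pip install -r requirements.txt
    python app.py --demo   # Gradio test UI; without --demo the script loads the models and stops at the stub

The first run downloads the Whisper, OPT-1.3B and SpeechT5 checkpoints from the Hugging Face Hub. numpy, which app.py still imports, is no longer pinned directly and arrives transitively via gradio/torch; if ReplyOnPause reports missing VAD dependencies, the fastrtc[vad] extra from the old requirements may still be needed.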