hashhac committed on
Commit 4fb650d
1 Parent(s): 95dee6f
Files changed (2)
  1. app.py +261 -0
  2. requirements.txt +15 -0
app.py ADDED
@@ -0,0 +1,261 @@
+ import gradio as gr
+ import numpy as np
+ import torch
+ import os
+ import tempfile
+ from transformers import (
+     AutoModelForSpeechSeq2Seq,
+     AutoProcessor,
+     pipeline,
+     AutoTokenizer,
+     AutoModelForCausalLM
+ )
+
+ # Check if CUDA is available, otherwise use CPU
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+ # Initialize pyttsx3 for local TTS
+ def load_local_tts():
+     import pyttsx3
+
+     engine = pyttsx3.init()
+     engine.setProperty('rate', 150)    # Speed of speech
+     engine.setProperty('volume', 0.9)  # Volume (0.0-1.0)
+
+     voices = engine.getProperty('voices')
+     if len(voices) > 1:
+         engine.setProperty('voice', voices[1].id)  # Use a non-default (often female) voice when available
+
+     return engine
+
+ # Initialize the TTS engine
+ print("Loading local TTS engine...")
+ tts_engine = load_local_tts()
+
+ def text_to_speech_local(text):
+     """Convert text to speech using the pyttsx3 local TTS engine."""
+     import soundfile as sf
+
+     # Create a temporary file to store the audio
+     with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
+         temp_filename = temp_file.name
+
+     # Generate speech into the temporary file
+     tts_engine.save_to_file(text, temp_filename)
+     tts_engine.runAndWait()
+
+     # Read the audio file back
+     audio_data, sample_rate = sf.read(temp_filename)
+
+     # Down-mix to mono and keep a 1-D array, which is the layout gr.Audio expects
+     if audio_data.ndim > 1:
+         audio_data = audio_data[:, 0]
+
+     # Convert from float in [-1.0, 1.0] to int16 PCM
+     audio_data = (audio_data * 32767).astype(np.int16)
+
+     # Clean up the temporary file
+     os.unlink(temp_filename)
+
+     return (sample_rate, audio_data)
+
+ # Load ASR model (Whisper)
+ def load_asr_model():
+     model_id = "openai/whisper-small"
+
+     model = AutoModelForSpeechSeq2Seq.from_pretrained(
+         model_id,
+         torch_dtype=torch_dtype,
+         low_cpu_mem_usage=True,
+         use_safetensors=True
+     )
+     model.to(device)
+
+     processor = AutoProcessor.from_pretrained(model_id)
+
+     return pipeline(
+         "automatic-speech-recognition",
+         model=model,
+         tokenizer=processor.tokenizer,
+         feature_extractor=processor.feature_extractor,
+         max_new_tokens=128,
+         chunk_length_s=30,
+         batch_size=16,
+         return_timestamps=False,
+         torch_dtype=torch_dtype,
+         device=device,
+     )
+
+ # Load LLM model
+ def load_llm_model():
+     model_id = "facebook/opt-1.3b"
+
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+     if tokenizer.pad_token is None:
+         tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         torch_dtype=torch_dtype,
+         low_cpu_mem_usage=True
+     )
+
+     model.resize_token_embeddings(len(tokenizer))
+     model.config.pad_token_id = tokenizer.pad_token_id
+
+     if hasattr(model.config, "word_embed_proj_dim"):
+         model.config._remove_wrong_keys = False
+
+     model.to(device)
+
+     return model, tokenizer
+
+ # Initialize models
+ print("Loading ASR model...")
+ asr_pipeline = load_asr_model()
+
+ print("Loading LLM model...")
+ llm_model, llm_tokenizer = load_llm_model()
+
+ # Chat history management
+ chat_history = []
+
+ def generate_response(prompt):
+     # If chat history is empty, add a system message
+     if not chat_history:
+         chat_history.append({"role": "system", "content": "You are a helpful, friendly AI assistant. Keep your responses concise and conversational."})
+
+     # Add user message to history
+     chat_history.append({"role": "user", "content": prompt})
+
+     # Build prompt from chat history
+     full_prompt = ""
+     for message in chat_history:
+         if message["role"] == "system":
+             full_prompt += f"System: {message['content']}\n"
+         elif message["role"] == "user":
+             full_prompt += f"User: {message['content']}\n"
+         elif message["role"] == "assistant":
+             full_prompt += f"Assistant: {message['content']}\n"
+
+     full_prompt += "Assistant: "
+
+     # Encode input
+     encoded_input = llm_tokenizer.encode_plus(
+         full_prompt,
+         return_tensors="pt",
+         padding=False,
+         add_special_tokens=True,
+         return_attention_mask=True
+     )
+
+     input_ids = encoded_input["input_ids"].to(device)
+     attention_mask = encoded_input["attention_mask"].to(device)
+
+     # Generate response
+     with torch.no_grad():
+         try:
+             output = llm_model.generate(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 max_new_tokens=128,
+                 do_sample=True,
+                 temperature=0.7,
+                 top_p=0.9,
+                 pad_token_id=llm_tokenizer.pad_token_id,
+                 eos_token_id=llm_tokenizer.eos_token_id,
+                 use_cache=True
+             )
+         except Exception:
+             # Fall back to a minimal generate call if the full argument set fails
+             output = llm_model.generate(
+                 input_ids=input_ids,
+                 max_new_tokens=128,
+                 do_sample=True,
+                 temperature=0.7
+             )
+
+     # Decode output and keep only the newly generated assistant turn
+     response_text = llm_tokenizer.decode(output[0], skip_special_tokens=True)
+     response_text = response_text.split("Assistant: ")[-1].strip()
+     # The model may keep writing a fabricated "User:" turn; cut it off
+     response_text = response_text.split("User:")[0].strip()
+
+     # Add assistant response to history
+     chat_history.append({"role": "assistant", "content": response_text})
+
+     # Keep history manageable (drop the oldest non-system message)
+     if len(chat_history) > 10:
+         chat_history.pop(1)
+
+     return response_text
+
+ def demo():
+     with gr.Blocks() as demo:
+         gr.Markdown("# Voice Chatbot")
+         gr.Markdown("Simply speak into the microphone and get an audio response.")
+
+         audio_input = gr.Audio(sources=["microphone"], type="numpy", label="Speak")
+         audio_output = gr.Audio(label="Response", autoplay=True)
+         transcript_display = gr.Textbox(label="Conversation")
+
+         def process_audio(audio):
+             if audio is None:
+                 return None, "No audio detected."
+
+             # Track conversation for display
+             conversation_text = ""
+
+             # Unpack the recording
+             sample_rate, audio_array = audio
+
+             # Keep one channel and convert int16 samples to float32 in [-1.0, 1.0] for ASR
+             if audio_array.ndim > 1:
+                 audio_array = audio_array[:, 0]
+             audio_float32 = audio_array.astype(np.float32) / 32768.0
+
+             # Speech-to-text
+             transcript = asr_pipeline({
+                 "sampling_rate": sample_rate,
+                 "raw": audio_float32
+             })
+
+             prompt = transcript["text"]
+             conversation_text += f"You: {prompt}\n"
+             print(f"Transcribed: {prompt}")
+
+             # Generate response
+             response_text = generate_response(prompt)
+             conversation_text += f"Assistant: {response_text}\n"
+             print(f"Response: {response_text}")
+
+             # Convert to speech (returns a 1-D int16 array gr.Audio can play directly)
+             sample_rate, audio_array = text_to_speech_local(response_text)
+
+             return (sample_rate, audio_array), conversation_text
+
+         audio_input.change(process_audio,
+                            inputs=[audio_input],
+                            outputs=[audio_output, transcript_display])
+
+         clear_btn = gr.Button("Clear Conversation")
+         clear_btn.click(lambda: (None, ""), outputs=[audio_output, transcript_display])
+
+         # Also clear the model-side chat history
+         def reset_chat():
+             global chat_history
+             chat_history = []
+             return None, "Conversation history cleared."
+
+         reset_btn = gr.Button("Reset Chat History")
+         reset_btn.click(reset_chat, outputs=[audio_output, transcript_display])
+
+     demo.launch()
+
+ if __name__ == "__main__":
+     demo()
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ transformers
+ torch
+ datasets
+ scipy
+ fastrtc
+ gradio
+ accelerate
+ sentencepiece
+ fastrtc[vad,tts]
+ torchaudio
+ gtts
+ pydub
+ pyttsx3
+ soundfile