ceymox committed on
Commit
78c273c
·
verified ·
1 Parent(s): 5df9635

Create app.py

Files changed (1)
  1. app.py +1616 -0
app.py ADDED
@@ -0,0 +1,1616 @@
1
+ import os
2
+ import io
3
+ import time
4
+ import torch
5
+ import librosa
6
+ import requests
7
+ import tempfile
8
+ import threading
9
+ import queue
10
+ import traceback
11
+ import numpy as np
12
+ import soundfile as sf
13
+ import gradio as gr
14
+ from datetime import datetime
15
+ from transformers import AutoModel, pipeline, logging as trf_logging
16
+ from huggingface_hub import login, scan_cache_dir
17
+ import speech_recognition as sr
18
+ import openai
19
+
20
+ # Set up environment variables and timeouts
21
+ os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "300" # 5-minute timeout
22
+
23
+ # Enable verbose logging
24
+ trf_logging.set_verbosity_info()
25
+
26
+ # Get API keys from environment
27
+ HF_TOKEN = os.getenv("HF_TOKEN")
28
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
29
+
30
+ # Set OpenAI API key
31
+ openai.api_key = OPENAI_API_KEY
32
+
33
+ # Login to Hugging Face
34
+ if HF_TOKEN:
35
+ print("🔐 Logging into Hugging Face with token...")
36
+ login(token=HF_TOKEN)
37
+ else:
38
+ print("⚠️ HF_TOKEN not found. Proceeding without login...")
39
+
40
+ # Set up device (GPU if available, otherwise CPU)
41
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
+ print(f"🔧 Using device: {device}")
43
+
44
+ # Initialize model variables
45
+ tts_model = None
46
+ asr_model = None
+ tts_model_wrapper = None # set by load_tts_model_with_retry(); checked before streaming TTS
47
+
48
+ # Define repository IDs
49
+ tts_repo_id = "ai4bharat/IndicF5"
50
+ asr_repo_id = "facebook/wav2vec2-large-xlsr-53" # Multilingual wav2vec2 checkpoint (pretraining-only; the ASR pipeline generally needs a CTC fine-tuned variant, hence the Google fallback below)
51
+
52
+ # TTS model wrapper class to standardize the interface
53
+ class TTSModelWrapper:
54
+ def __init__(self, model):
55
+ self.model = model
56
+
57
+ def generate(self, text, ref_audio_path, ref_text):
58
+ try:
59
+ if self.model is None:
60
+ raise ValueError("Model not initialized")
61
+
62
+ output = self.model(
63
+ text,
64
+ ref_audio_path=ref_audio_path,
65
+ ref_text=ref_text
66
+ )
67
+ return output
68
+ except Exception as e:
69
+ print(f"Error in TTS generation: {e}")
70
+ traceback.print_exc()
71
+ return None
72
+
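+ # Hedged usage sketch for TTSModelWrapper (not called anywhere; "ref.wav" and
+ # its transcript are placeholder values). The keyword signature mirrors how
+ # the wrapper is invoked throughout this file.
+ def _demo_tts_wrapper(ref_wav_path="ref.wav", ref_transcript="..."):
+ wrapper = TTSModelWrapper(tts_model)
+ audio = wrapper.generate("നമസ്കാരം", ref_audio_path=ref_wav_path, ref_text=ref_transcript)
+ if audio is not None:
+ sf.write("demo_out.wav", audio, 24000)
+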
73
+ def load_tts_model_with_retry(max_retries=3, retry_delay=5):
74
+ global tts_model, tts_model_wrapper
75
+
76
+ # First, check if model is already in cache
77
+ print("Checking if TTS model is in cache...")
78
+ try:
79
+ cache_info = scan_cache_dir()
80
+ model_in_cache = any(tts_repo_id in repo.repo_id for repo in cache_info.repos)
81
+ if model_in_cache:
82
+ print(f"Model {tts_repo_id} found in cache, loading locally...")
83
+ tts_model = AutoModel.from_pretrained(
84
+ tts_repo_id,
85
+ trust_remote_code=True,
86
+ local_files_only=True
87
+ ).to(device)
88
+ tts_model_wrapper = TTSModelWrapper(tts_model)
89
+ print("TTS model loaded from cache successfully!")
90
+ return
91
+ except Exception as e:
92
+ print(f"Cache check failed: {e}")
93
+
94
+ # If not in cache or cache check failed, try loading with retries
95
+ for attempt in range(max_retries):
96
+ try:
97
+ print(f"Loading {tts_repo_id} model (attempt {attempt+1}/{max_retries})...")
98
+ tts_model = AutoModel.from_pretrained(
99
+ tts_repo_id,
100
+ trust_remote_code=True,
101
+ revision="main",
102
+ use_auth_token=HF_TOKEN,
103
+ low_cpu_mem_usage=True
104
+ ).to(device)
105
+
106
+ tts_model_wrapper = TTSModelWrapper(tts_model)
107
+ print(f"TTS model loaded successfully! Type: {type(tts_model)}")
108
+ return # Success, exit function
109
+
110
+ except Exception as e:
111
+ print(f"⚠️ Attempt {attempt+1}/{max_retries} failed: {e}")
112
+ if attempt < max_retries - 1:
113
+ print(f"Waiting {retry_delay} seconds before retrying...")
114
+ time.sleep(retry_delay)
115
+ retry_delay *= 1.5 # Exponential backoff
116
+
117
+ # If all attempts failed, try one last time with fallback options
118
+ try:
119
+ print("Trying with fallback options...")
120
+ tts_model = AutoModel.from_pretrained(
121
+ tts_repo_id,
122
+ trust_remote_code=True,
123
+ revision="main",
124
+ local_files_only=False,
125
+ use_auth_token=HF_TOKEN,
126
+ force_download=False,
127
+ resume_download=True
128
+ ).to(device)
129
+ tts_model_wrapper = TTSModelWrapper(tts_model)
130
+ print("TTS model loaded with fallback options!")
131
+ except Exception as e2:
132
+ print(f"❌ All attempts to load TTS model failed: {e2}")
133
+ print("Will continue without TTS model loaded.")
134
+
135
+ def load_asr_model():
136
+ global asr_model
137
+ try:
138
+ print(f"Loading ASR model from {asr_repo_id}...")
139
+ asr_model = pipeline("automatic-speech-recognition", model=asr_repo_id, device=device)
140
+ print("ASR model loaded successfully!")
141
+ except Exception as e:
142
+ print(f"Error loading ASR model: {e}")
143
+ print("Will use Google's speech recognition API instead.")
144
+ asr_model = None
145
+
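+ # Note: neither loader above is called at import time in this file; the Gradio
+ # app below constructs its own ConversationEngine, which loads IndicF5 itself.
+ # A module-level startup, if desired, would simply be:
+ # load_tts_model_with_retry()
+ # load_asr_model()
+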
146
+ class SpeechRecognizer:
147
+ def __init__(self):
148
+ self.recognizer = sr.Recognizer()
149
+ self.using_huggingface = asr_model is not None
150
+
151
+ def recognize_from_file(self, audio_path, language="ml-IN"):
152
+ """Recognize speech from audio file with fallback mechanisms"""
153
+ print(f"Recognizing speech from {audio_path}")
154
+ try:
155
+ # Try Hugging Face model first if available
156
+ if self.using_huggingface:
157
+ try:
158
+ result = asr_model(audio_path)
159
+ transcription = result["text"]
160
+ print(f"HF ASR result: {transcription}")
161
+ return transcription
162
+ except Exception as e:
163
+ print(f"HF ASR failed: {e}, falling back to Google")
164
+
165
+ # Fallback to Google's ASR
166
+ with sr.AudioFile(audio_path) as source:
167
+ audio_data = self.recognizer.record(source)
168
+ text = self.recognizer.recognize_google(audio_data, language=language)
169
+ print(f"Google ASR result: {text}")
170
+ return text
171
+ except Exception as e:
172
+ print(f"Speech recognition failed: {e}")
173
+ return ""
174
+
175
+ def recognize_from_microphone(self, language="ml-IN", timeout=5):
176
+ """Recognize speech from microphone"""
177
+ print("Listening to microphone...")
178
+ try:
179
+ with sr.Microphone() as source:
180
+ self.recognizer.adjust_for_ambient_noise(source)
181
+ print("Speak now...")
182
+ try:
183
+ audio = self.recognizer.listen(source, timeout=timeout)
184
+ print("Processing speech...")
185
+
186
+ # Save audio to temporary file for potential HF model processing
187
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
188
+ temp_file.close()
189
+
190
+ with open(temp_file.name, "wb") as f:
191
+ f.write(audio.get_wav_data())
192
+
193
+ # Process with available model
194
+ if self.using_huggingface:
195
+ try:
196
+ result = asr_model(temp_file.name)
197
+ text = result["text"]
198
+ print(f"HF ASR result: {text}")
199
+ os.unlink(temp_file.name)
200
+ return text
201
+ except Exception as e:
202
+ print(f"HF ASR failed: {e}, falling back to Google")
203
+
204
+ # Fallback to Google
205
+ text = self.recognizer.recognize_google(audio, language=language)
206
+ print(f"Google ASR result: {text}")
207
+ os.unlink(temp_file.name)
208
+ return text
209
+
210
+ except sr.WaitTimeoutError:
211
+ print("No speech detected within timeout period")
212
+ return ""
213
+ except Exception as e:
214
+ print(f"Speech recognition error: {e}")
215
+ return ""
216
+ except Exception as e:
217
+ print(f"Microphone access error: {e}")
218
+ return ""
219
+
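+ # Minimal usage sketch for SpeechRecognizer (the file path is a placeholder;
+ # a real WAV file and network access for the Google fallback are assumed):
+ def _demo_speech_recognizer(path="sample.wav"):
+ rec = SpeechRecognizer()
+ print(rec.recognize_from_file(path, language="ml-IN"))
+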
220
+ class ConversationManager:
221
+ def __init__(self):
222
+ self.conversation_history = []
223
+ self.system_prompt = (
224
+ "You are a helpful, friendly assistant who speaks Malayalam fluently. "
225
+ "Keep your responses concise and conversational. "
226
+ "If the user speaks in English, you can respond in English. "
227
+ "If the user speaks in Malayalam, respond in Malayalam."
228
+ )
229
+
230
+ def add_message(self, role, content):
231
+ self.conversation_history.append({"role": role, "content": content})
232
+
233
+ def get_formatted_history(self):
234
+ """Format conversation history for OpenAI API"""
235
+ messages = [{"role": "system", "content": self.system_prompt}]
236
+
237
+ for msg in self.conversation_history:
238
+ if msg["role"] == "user":
239
+ messages.append({"role": "user", "content": msg["content"]})
240
+ else:
241
+ messages.append({"role": "assistant", "content": msg["content"]})
242
+
243
+ return messages
244
+
245
+ def generate_response(self, user_input):
246
+ """Generate response using GPT-3.5 Turbo"""
247
+ if not openai.api_key:
248
+ return "I'm sorry, but the language model is not available right now."
249
+
250
+ self.add_message("user", user_input)
251
+
252
+ try:
253
+ # Format history for the model
254
+ messages = self.get_formatted_history()
255
+ print(f"Sending prompt to OpenAI: {len(messages)} messages")
256
+
257
+ # Generate response with GPT-3.5 Turbo
258
+ response = openai.ChatCompletion.create(
259
+ model="gpt-3.5-turbo",
260
+ messages=messages,
261
+ max_tokens=300,
262
+ temperature=0.7,
263
+ top_p=0.9,
264
+ )
265
+
266
+ # Extract text response
267
+ response_text = response.choices[0].message["content"].strip()
268
+ print(f"GPT-3.5 response: {response_text}")
269
+
270
+ # Add to history
271
+ self.add_message("assistant", response_text)
272
+
273
+ return response_text
274
+
275
+ except Exception as e:
276
+ print(f"Error generating response: {e}")
277
+ fallback_response = "I'm having trouble thinking right now. Can we try again?"
278
+ self.add_message("assistant", fallback_response)
279
+ return fallback_response
280
+
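+ # Illustrative round trip with ConversationManager (requires OPENAI_API_KEY).
+ # Note that openai.ChatCompletion.create is the legacy (<1.0) SDK call; on
+ # openai>=1.0 the equivalent is OpenAI().chat.completions.create(...).
+ def _demo_conversation_manager():
+ cm = ConversationManager()
+ print(cm.generate_response("Hello!"))
+ print(cm.get_formatted_history())
+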
281
+ def remove_noise(audio_data, threshold=0.01):
282
+ """Apply simple noise gate to remove low-level noise"""
283
+ if audio_data is None:
284
+ return np.zeros(1000)
285
+
286
+ # Convert to numpy if needed
287
+ if isinstance(audio_data, torch.Tensor):
288
+ audio_data = audio_data.detach().cpu().numpy()
289
+ if isinstance(audio_data, list):
290
+ audio_data = np.array(audio_data)
291
+
292
+ # Apply noise gate
293
+ noise_mask = np.abs(audio_data) < threshold
294
+ clean_audio = audio_data.copy()
295
+ clean_audio[noise_mask] = 0
296
+
297
+ return clean_audio
298
+
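+ # Tiny illustration of the noise gate: samples below the threshold are zeroed.
+ def _demo_remove_noise():
+ x = np.array([0.005, -0.3, 0.002, 0.8], dtype=np.float32)
+ print(remove_noise(x, threshold=0.01)) # -> [ 0. -0.3 0. 0.8]
+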
299
+ def apply_smoothing(audio_data, window_size=5):
300
+ """Apply gentle smoothing to reduce artifacts"""
301
+ if audio_data is None or len(audio_data) < window_size*2:
302
+ return audio_data
303
+
304
+ # Simple moving average filter
305
+ kernel = np.ones(window_size) / window_size
306
+ smoothed = np.convolve(audio_data, kernel, mode='same')
307
+
308
+ # Keep original at the edges
309
+ smoothed[:window_size] = audio_data[:window_size]
310
+ smoothed[-window_size:] = audio_data[-window_size:]
311
+
312
+ return smoothed
313
+
314
+ def enhance_audio(audio_data):
315
+ """Process audio to improve quality and reduce noise"""
316
+ if audio_data is None:
317
+ return np.zeros(1000)
318
+
319
+ # Ensure numpy array
320
+ if isinstance(audio_data, torch.Tensor):
321
+ audio_data = audio_data.detach().cpu().numpy()
322
+ if isinstance(audio_data, list):
323
+ audio_data = np.array(audio_data)
324
+
325
+ # Ensure correct shape and dtype
326
+ if len(audio_data.shape) > 1:
327
+ audio_data = audio_data.flatten()
328
+ if audio_data.dtype != np.float32:
329
+ audio_data = audio_data.astype(np.float32)
330
+
331
+ # Skip processing if audio is empty or too short
332
+ if audio_data.size < 100:
333
+ return audio_data
334
+
335
+ # Check if the audio has reasonable amplitude
336
+ rms = np.sqrt(np.mean(audio_data**2))
337
+ print(f"Initial RMS: {rms}")
338
+
339
+ # Apply gain if needed
340
+ if rms < 0.05: # Very quiet
341
+ target_rms = 0.2
342
+ gain = target_rms / max(rms, 0.0001)
343
+ print(f"Applying gain factor: {gain}")
344
+ audio_data = audio_data * gain
345
+
346
+ # Remove DC offset
347
+ audio_data = audio_data - np.mean(audio_data)
348
+
349
+ # Apply noise gate to remove low-level noise
350
+ audio_data = remove_noise(audio_data, threshold=0.01)
351
+
352
+ # Apply gentle smoothing to reduce artifacts
353
+ audio_data = apply_smoothing(audio_data, window_size=3)
354
+
355
+ # Apply soft limiting to prevent clipping
356
+ max_amp = np.max(np.abs(audio_data))
357
+ if max_amp > 0.95:
358
+ audio_data = 0.95 * audio_data / max_amp
359
+
360
+ # Apply subtle compression for better audibility
361
+ audio_data = np.tanh(audio_data * 1.1) * 0.9
362
+
363
+ return audio_data
364
+
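+ # Self-contained sanity check for enhance_audio (illustrative only): a quiet
+ # 440 Hz tone should come back gain-boosted, DC-free, and peak-limited.
+ def _demo_enhance_audio():
+ t = np.linspace(0, 1.0, 24000, endpoint=False)
+ quiet_tone = (0.01 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)
+ enhanced = enhance_audio(quiet_tone)
+ print(f"peak before: {np.max(np.abs(quiet_tone)):.3f}, after: {np.max(np.abs(enhanced)):.3f}")
+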
365
+ def split_into_chunks(text, max_length=30):
366
+ """Split text into smaller chunks based on punctuation and length"""
367
+ # First split by sentences
368
+ sentence_markers = ['.', '?', '!', ';', ':', '।', '॥']
369
+ chunks = []
370
+ current = ""
371
+
372
+ # Initial coarse splitting by sentence markers
373
+ for char in text:
374
+ current += char
375
+ if char in sentence_markers and current.strip():
376
+ chunks.append(current.strip())
377
+ current = ""
378
+
379
+ if current.strip():
380
+ chunks.append(current.strip())
381
+
382
+ # Further break down long sentences
383
+ final_chunks = []
384
+ for chunk in chunks:
385
+ if len(chunk) <= max_length:
386
+ final_chunks.append(chunk)
387
+ else:
388
+ # Try splitting by commas for long sentences
389
+ comma_splits = chunk.split(',')
390
+ current_part = ""
391
+
392
+ for part in comma_splits:
393
+ if len(current_part) + len(part) <= max_length:
394
+ if current_part:
395
+ current_part += ","
396
+ current_part += part
397
+ else:
398
+ if current_part:
399
+ final_chunks.append(current_part.strip())
400
+ current_part = part
401
+
402
+ if current_part:
403
+ final_chunks.append(current_part.strip())
404
+
405
+ print(f"Split text into {len(final_chunks)} chunks")
406
+ return final_chunks
407
+
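+ # Behavior sketch for split_into_chunks (the sample sentence is illustrative):
+ # text is first cut at sentence markers, then long pieces are re-split at
+ # commas so each chunk stays near max_length.
+ def _demo_split_into_chunks():
+ sample = "This is a short one. This sentence, on the other hand, is long enough, that it must be split at commas!"
+ for piece in split_into_chunks(sample, max_length=30):
+ print(repr(piece))
+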
408
+ class StreamingTTS:
409
+ def __init__(self):
410
+ self.is_generating = False
411
+ self.should_stop = False
412
+ self.temp_dir = None
413
+ self.ref_audio_path = None
414
+ self.output_file = None
415
+ self.all_chunks = []
416
+ self.sample_rate = 24000 # Default sample rate
417
+ self.current_text = "" # Track current text being processed
418
+
419
+ # Create temp directory
420
+ try:
421
+ self.temp_dir = tempfile.mkdtemp()
422
+ print(f"Created temp directory: {self.temp_dir}")
423
+ except Exception as e:
424
+ print(f"Error creating temp directory: {e}")
425
+ self.temp_dir = "." # Use current directory as fallback
426
+
427
+ def prepare_ref_audio(self, ref_audio, ref_sr):
428
+ """Prepare reference audio with enhanced quality"""
429
+ try:
430
+ if self.ref_audio_path is None:
431
+ self.ref_audio_path = os.path.join(self.temp_dir, "ref_audio.wav")
432
+
433
+ # Process the reference audio to ensure clean quality
434
+ ref_audio = enhance_audio(ref_audio)
435
+
436
+ # Save the reference audio
437
+ sf.write(self.ref_audio_path, ref_audio, ref_sr, format='WAV', subtype='FLOAT')
438
+ print(f"Saved reference audio to: {self.ref_audio_path}")
439
+
440
+ # Verify file was created
441
+ if os.path.exists(self.ref_audio_path):
442
+ print(f"Reference audio saved successfully: {os.path.getsize(self.ref_audio_path)} bytes")
443
+ else:
444
+ print("⚠️ Failed to create reference audio file!")
445
+
446
+ # Create output file
447
+ if self.output_file is None:
448
+ self.output_file = os.path.join(self.temp_dir, "output.wav")
449
+ print(f"Output will be saved to: {self.output_file}")
450
+ except Exception as e:
451
+ print(f"Error preparing reference audio: {e}")
452
+
453
+ def cleanup(self):
454
+ """Clean up temporary files"""
455
+ if self.temp_dir:
456
+ try:
457
+ if self.ref_audio_path and os.path.exists(self.ref_audio_path):
458
+ os.remove(self.ref_audio_path)
459
+ if self.output_file and os.path.exists(self.output_file):
460
+ os.remove(self.output_file)
461
+ os.rmdir(self.temp_dir)
462
+ self.temp_dir = None
463
+ print("Cleaned up temporary files")
464
+ except Exception as e:
465
+ print(f"Error cleaning up: {e}")
466
+
467
+ def generate(self, text, ref_audio, ref_sr, ref_text):
468
+ """Start generation in a new thread with validation"""
469
+ if self.is_generating:
470
+ print("Already generating speech, please wait")
471
+ return
472
+
473
+ # Store the text for verification
474
+ self.current_text = text
475
+ print(f"Setting current text to: '{self.current_text}'")
476
+
477
+ # Check model is loaded
478
+ if tts_model_wrapper is None or tts_model is None:
479
+ print("⚠️ Model is not loaded. Cannot generate speech.")
480
+ return
481
+
482
+ self.is_generating = True
483
+ self.should_stop = False
484
+ self.all_chunks = []
485
+
486
+ # Start in a new thread
487
+ threading.Thread(
488
+ target=self._process_streaming,
489
+ args=(text, ref_audio, ref_sr, ref_text),
490
+ daemon=True
491
+ ).start()
492
+
493
+ def _process_streaming(self, text, ref_audio, ref_sr, ref_text):
494
+ """Process text in chunks with high-quality audio generation"""
495
+ try:
496
+ # Double check text matches what we expect
497
+ if text != self.current_text:
498
+ print(f"⚠️ Text mismatch detected! Expected: '{self.current_text}', Got: '{text}'")
499
+ # Use the stored text to be safe
500
+ text = self.current_text
501
+
502
+ # Prepare reference audio
503
+ self.prepare_ref_audio(ref_audio, ref_sr)
504
+
505
+ # Print the text we're actually going to process
506
+ print(f"Processing text: '{text}'")
507
+
508
+ # Split text into smaller chunks for faster processing
509
+ chunks = split_into_chunks(text)
510
+ print(f"Processing {len(chunks)} chunks")
511
+
512
+ combined_audio = None
513
+ total_start_time = time.time()
514
+
515
+ # Process each chunk
516
+ for i, chunk in enumerate(chunks):
517
+ if self.should_stop:
518
+ print("Stopping generation as requested")
519
+ break
520
+
521
+ chunk_start = time.time()
522
+ print(f"Processing chunk {i+1}/{len(chunks)}: '{chunk}'")
523
+
524
+ # Generate speech for this chunk
525
+ try:
526
+ # Note: no per-chunk timeout is enforced here; each chunk's inference runs to completion
528
+
529
+ with torch.inference_mode():
530
+ # Explicitly pass the chunk text
531
+ chunk_audio = tts_model_wrapper.generate(
532
+ text=chunk, # Make sure we're using the current chunk
533
+ ref_audio_path=self.ref_audio_path,
534
+ ref_text=ref_text
535
+ )
536
+
537
+ if chunk_audio is None or (hasattr(chunk_audio, 'size') and chunk_audio.size == 0):
538
+ print("⚠️ Empty audio returned for this chunk")
539
+ chunk_audio = np.zeros(int(24000 * 0.5)) # 0.5s silence
540
+
541
+ # Process the audio to improve quality
542
+ chunk_audio = enhance_audio(chunk_audio)
543
+
544
+ chunk_time = time.time() - chunk_start
545
+ print(f"✓ Chunk {i+1} processed in {chunk_time:.2f}s")
546
+
547
+ # Add small silence between chunks
548
+ silence = np.zeros(int(24000 * 0.1)) # 0.1s silence
549
+ chunk_audio = np.concatenate([chunk_audio, silence])
550
+
551
+ # Add to our collection
552
+ self.all_chunks.append(chunk_audio)
553
+
554
+ # Combine all chunks so far
555
+ if combined_audio is None:
556
+ combined_audio = chunk_audio
557
+ else:
558
+ combined_audio = np.concatenate([combined_audio, chunk_audio])
559
+
560
+ # Process combined audio for consistent quality
561
+ processed_audio = enhance_audio(combined_audio)
562
+
563
+ # Write intermediate output
564
+ sf.write(self.output_file, processed_audio, 24000, format='WAV', subtype='FLOAT')
565
+
566
+ except Exception as e:
567
+ print(f"Error processing chunk {i+1}: {str(e)[:100]}")
568
+ continue
569
+
570
+ total_time = time.time() - total_start_time
571
+ print(f"Total generation time: {total_time:.2f}s")
572
+
573
+ except Exception as e:
574
+ print(f"Error in streaming TTS: {str(e)[:200]}")
575
+ # Try to write whatever we have so far
576
+ if len(self.all_chunks) > 0:
577
+ try:
578
+ combined = np.concatenate(self.all_chunks)
579
+ sf.write(self.output_file, combined, 24000, format='WAV', subtype='FLOAT')
580
+ print("Saved partial output")
581
+ except Exception as e2:
582
+ print(f"Failed to save partial output: {e2}")
583
+ finally:
584
+ self.is_generating = False
585
+ print("Generation complete")
586
+
587
+ def get_current_audio(self):
588
+ """Get current audio file path for Gradio"""
589
+ if self.output_file and os.path.exists(self.output_file):
590
+ file_size = os.path.getsize(self.output_file)
591
+ if file_size > 0:
592
+ return self.output_file
593
+ return None
594
+
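+ # Hedged sketch of driving StreamingTTS directly. ref_audio/ref_sr/ref_text
+ # would come from a loaded reference clip, and the globals tts_model and
+ # tts_model_wrapper must be loaded first or generate() returns immediately.
+ def _demo_streaming_tts(ref_audio, ref_sr, ref_text):
+ stts = StreamingTTS()
+ stts.generate("ഇതൊരു പരീക്ഷണ വാചകമാണ്.", ref_audio, ref_sr, ref_text)
+ while stts.is_generating:
+ time.sleep(0.5)
+ print("Generated audio at:", stts.get_current_audio())
+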
595
+ class ConversationEngine:
596
+ def __init__(self):
597
+ self.conversation_history = []
598
+ self.system_prompt = "You are a helpful assistant that speaks Malayalam fluently. Always respond in Malayalam script with proper formatting."
599
+ self.saved_voice = None
600
+ self.saved_voice_text = ""
601
+ self.tts_cache = {} # Cache for TTS outputs
602
+
603
+ # TTS background processing queue
604
+ self.tts_queue = queue.Queue()
605
+ self.tts_thread = threading.Thread(target=self.tts_worker, daemon=True)
606
+ self.tts_thread.start()
607
+
608
+ # Initialize streaming TTS
609
+ self.streaming_tts = StreamingTTS()
610
+
611
+ def tts_worker(self):
612
+ """Background worker to process TTS requests"""
613
+ while True:
614
+ try:
615
+ # Get text and callback from queue
616
+ text, callback = self.tts_queue.get()
617
+
618
+ # Generate speech
619
+ audio_path = self._generate_tts(text)
620
+
621
+ # Execute callback with result
622
+ if callback:
623
+ callback(audio_path)
624
+
625
+ # Mark task as done
626
+ self.tts_queue.task_done()
627
+ except Exception as e:
628
+ print(f"Error in TTS worker: {e}")
629
+ traceback.print_exc()
630
+
631
+ def transcribe_audio(self, audio_data, language="ml-IN"):
632
+ """Convert audio to text using speech recognition"""
633
+ if audio_data is None:
634
+ print("No audio data received")
635
+ return "No audio detected", ""
636
+
637
+ # Make sure we have audio data in the expected format
638
+ try:
639
+ if isinstance(audio_data, tuple) and len(audio_data) == 2:
640
+ # Expected format: (sample_rate, audio_samples)
641
+ sample_rate, audio_samples = audio_data
642
+ else:
643
+ print(f"Unexpected audio format: {type(audio_data)}")
644
+ return "Invalid audio format", ""
645
+
646
+ if len(audio_samples) == 0:
647
+ print("Empty audio samples")
648
+ return "No speech detected", ""
649
+
650
+ # Save the audio temporarily
651
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
652
+ temp_file.close()
653
+
654
+ # Save the audio data to the temporary file
655
+ sf.write(temp_file.name, audio_samples, sample_rate)
656
+
657
+ # Use speech recognition on the file
658
+ recognizer = sr.Recognizer()
659
+ with sr.AudioFile(temp_file.name) as source:
660
+ audio = recognizer.record(source)
661
+
662
+ text = recognizer.recognize_google(audio, language=language)
663
+ print(f"Recognized: {text}")
664
+ return text, text
665
+
666
+ except sr.UnknownValueError:
667
+ print("Speech recognition could not understand audio")
668
+ return "Could not understand audio", ""
669
+ except sr.RequestError as e:
670
+ print(f"Could not request results from Google Speech Recognition service: {e}")
671
+ return f"Speech recognition service error: {str(e)}", ""
672
+ except Exception as e:
673
+ print(f"Error processing audio: {e}")
674
+ traceback.print_exc()
675
+ return f"Error processing audio: {str(e)}", ""
676
+ finally:
677
+ # Clean up temporary file
678
+ if 'temp_file' in locals() and os.path.exists(temp_file.name):
679
+ try:
680
+ os.unlink(temp_file.name)
681
+ except Exception as e:
682
+ print(f"Error deleting temporary file: {e}")
683
+
684
+ def save_reference_voice(self, audio_data, reference_text):
685
+ """Save the reference voice for future TTS generation"""
686
+ if audio_data is None or not reference_text.strip():
687
+ return "Error: Both reference audio and text are required"
688
+
689
+ self.saved_voice = audio_data
690
+ self.saved_voice_text = reference_text.strip()
691
+
692
+ # Clear TTS cache when voice changes
693
+ self.tts_cache.clear()
694
+
695
+ # Debug info
696
+ sample_rate, audio_samples = audio_data
697
+ print(f"Saved reference voice: {len(audio_samples)} samples at {sample_rate}Hz")
698
+ print(f"Reference text: {reference_text}")
699
+
700
+ return f"Voice saved successfully! Reference text: {reference_text}"
701
+
702
+ def process_text_input(self, text):
703
+ """Process text input from user"""
704
+ if text and text.strip():
705
+ return text, text
706
+ return "No input provided", ""
707
+
708
+ def generate_response(self, input_text):
709
+ """Generate AI response using GPT-3.5 Turbo"""
710
+ if not input_text or not input_text.strip():
711
+ return "ഇൻപുട്ട് ലഭിച്ചില്ല. വീണ്ടും ശ്രമിക്കുക.", None # "No input received. Please try again."
712
+
713
+ try:
714
+ # Prepare conversation context from history
715
+ messages = [{"role": "system", "content": self.system_prompt}]
716
+
717
+ # Add previous conversations for context
718
+ for entry in self.conversation_history:
719
+ role = "user" if entry["role"] == "user" else "assistant"
720
+ messages.append({"role": role, "content": entry["content"]})
721
+
722
+ # Add current input
723
+ messages.append({"role": "user", "content": input_text})
724
+
725
+ # Call OpenAI API
726
+ response = openai.ChatCompletion.create(
727
+ model="gpt-3.5-turbo",
728
+ messages=messages,
729
+ max_tokens=500,
730
+ temperature=0.7
731
+ )
732
+
733
+ response_text = response.choices[0].message["content"].strip()
734
+ return response_text, None
735
+
736
+ except Exception as e:
737
+ error_msg = f"എറർ: GPT മോഡലിൽ നിന്ന് ഉത്തരം ലഭിക്കുന്നതിൽ പ്രശ്നമുണ്ടായി: {str(e)}"
738
+ print(f"Error in GPT response: {e}")
739
+ traceback.print_exc()
740
+ return error_msg, None
741
+
742
+ def resample_audio(self, audio, orig_sr, target_sr):
743
+ """Resample audio to match target sample rate only if necessary"""
744
+ if orig_sr != target_sr:
745
+ print(f"Resampling audio from {orig_sr}Hz to {target_sr}Hz")
746
+ return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
747
+ return audio
748
+
749
+ def _generate_tts(self, text):
750
+ """Internal method to generate TTS without threading"""
751
+ if not text or not text.strip():
752
+ print("No text provided for TTS generation")
753
+ return None
754
+
755
+ # Check cache first
756
+ if text in self.tts_cache:
757
+ print("Using cached TTS output")
758
+ return self.tts_cache[text]
759
+
760
+ try:
761
+ # Check if we have a saved voice and the TTS model
762
+ if self.saved_voice is not None and tts_model is not None:
763
+ sample_rate, audio_data = self.saved_voice
764
+
765
+ # Create a temporary file for the reference audio
766
+ ref_temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
767
+ ref_temp_file.close()
768
+ print(f"Saving reference audio to {ref_temp_file.name}")
769
+
770
+ # Save the reference audio data
771
+ sf.write(ref_temp_file.name, audio_data, sample_rate)
772
+
773
+ # Create a temporary file for the output audio
774
+ output_temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
775
+ output_temp_file.close()
776
+
777
+ try:
778
+ # Generate speech using IndicF5 - simplified approach from second file
779
+ print(f"Generating speech with IndicF5. Text: {text[:30]}...")
780
+ start_time = time.time()
781
+
782
+ # Use torch.no_grad() to save memory and computation
783
+ with torch.no_grad():
784
+ # Run the inference using the wrapper
785
+ synth_audio = tts_model_wrapper.generate(
786
+ text,
787
+ ref_audio_path=ref_temp_file.name,
788
+ ref_text=self.saved_voice_text
789
+ )
790
+
791
+ end_time = time.time()
792
+ print(f"Speech generation completed in {end_time - start_time:.2f} seconds")
793
+
794
+ # Process audio for better quality
795
+ synth_audio = enhance_audio(synth_audio)
796
+
797
+ # Save the synthesized audio
798
+ sf.write(output_temp_file.name, synth_audio, 24000) # IndicF5 uses 24kHz
799
+
800
+ # Add to cache
801
+ self.tts_cache[text] = output_temp_file.name
802
+
803
+ print(f"TTS output saved to {output_temp_file.name}")
804
+ return output_temp_file.name
805
+
806
+ except Exception as e:
807
+ print(f"Error generating speech: {e}")
808
+ traceback.print_exc()
809
+ return None
810
+ finally:
811
+ # We don't delete the output file as it's returned to the caller
812
+ # But clean up reference file
813
+ try:
814
+ os.unlink(ref_temp_file.name)
815
+ except Exception as e:
816
+ print(f"Error cleaning up reference file: {e}")
817
+ else:
818
+ print("No saved voice reference or TTS model not loaded")
819
+ return None
820
+ except Exception as e:
821
+ print(f"Error in TTS processing: {e}")
822
+ traceback.print_exc()
823
+ return None
824
+
825
+ def queue_tts_generation(self, text, callback=None):
826
+ """Queue TTS generation in background thread"""
827
+ print(f"Queueing TTS generation for text: {text[:30]}...")
828
+ self.tts_queue.put((text, callback))
829
+
830
+ def generate_streamed_speech(self, text):
831
+ """Generate speech in a streaming manner for low latency"""
832
+ if not self.saved_voice:
833
+ print("No reference voice saved")
834
+ return None
835
+
836
+ if not text or not text.strip():
837
+ print("No text provided for streaming TTS")
838
+ return None
839
+
840
+ sample_rate, audio_data = self.saved_voice
841
+
842
+ # Start streaming generation
843
+ self.streaming_tts.generate(
844
+ text=text,
845
+ ref_audio=audio_data,
846
+ ref_sr=sample_rate,
847
+ ref_text=self.saved_voice_text
848
+ )
849
+
850
+ # Return the path that will be populated
851
+ return self.streaming_tts.output_file
852
+
853
+ def update_history(self, user_input, ai_response):
854
+ """Update conversation history"""
855
+ if user_input and user_input.strip():
856
+ self.conversation_history.append({"role": "user", "content": user_input})
857
+
858
+ if ai_response and ai_response.strip():
859
+ self.conversation_history.append({"role": "assistant", "content": ai_response})
860
+
861
+ # Limit history size
862
+ if len(self.conversation_history) > 20:
863
+ self.conversation_history = self.conversation_history[-20:]
864
+
865
+ # Initialize module-level instances. Note: these use the ConversationEngine defined above; the Gradio interface below instantiates the redefined class and never references these globals.
866
+ conversation_engine = ConversationEngine()
867
+ speech_recognizer = SpeechRecognizer()
868
+
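+ # NOTE: the class below redefines ConversationEngine. This second definition,
+ # which loads its own IndicF5 model and adds a gTTS fallback, shadows the one
+ # above and is the version instantiated by create_chatbot_interface().
+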
869
+ class ConversationEngine:
870
+ def __init__(self):
871
+ self.conversation_history = []
872
+ self.system_prompt = "You are a helpful assistant that speaks Malayalam fluently. Always respond in Malayalam script with proper formatting."
873
+ self.saved_voice = None
874
+ self.saved_voice_text = ""
875
+ self.tts_cache = {} # Cache for TTS outputs
876
+
877
+ # TTS background processing queue
878
+ self.tts_queue = queue.Queue()
879
+ self.tts_thread = threading.Thread(target=self.tts_worker, daemon=True)
880
+ self.tts_thread.start()
881
+
882
+ # Initialize IndicF5 TTS model if available
883
+ self.tts_model = None
884
+ self.device = None
885
+ try:
886
+ self.initialize_tts_model()
887
+
888
+ # Test the model if it was loaded successfully
889
+ if self.tts_model is not None:
890
+ print("TTS model initialized successfully")
891
+ except Exception as e:
892
+ print(f"Error initializing TTS model: {e}")
893
+ traceback.print_exc()
894
+
895
+ def initialize_tts_model(self):
896
+ """Initialize the IndicF5 TTS model with optimizations"""
897
+ try:
898
+ # Check for HF token in environment and use it if available
899
+ hf_token = os.getenv("HF_TOKEN")
900
+ if hf_token:
901
+ print("Logging into Hugging Face with the provided token.")
902
+ login(token=hf_token)
903
+
904
+ if torch.cuda.is_available():
905
+ self.device = torch.device("cuda")
906
+ print(f"Using GPU: {torch.cuda.get_device_name(0)}")
907
+ else:
908
+ self.device = torch.device("cpu")
909
+ print("Using CPU")
910
+
911
+ # Enable performance optimizations
912
+ torch.backends.cudnn.benchmark = True
913
+
914
+ # Load TTS model and move it to the appropriate device (GPU/CPU)
915
+ print("Loading TTS model from ai4bharat/IndicF5...")
916
+ repo_id = "ai4bharat/IndicF5"
917
+ self.tts_model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
918
+ self.tts_model = self.tts_model.to(self.device)
919
+
920
+ # Set model to evaluation mode for faster inference
921
+ self.tts_model.eval()
922
+ print("TTS model loaded successfully")
923
+ except Exception as e:
924
+ print(f"Failed to load TTS model: {e}")
925
+ self.tts_model = None
926
+ traceback.print_exc()
927
+
928
+ def tts_worker(self):
929
+ """Background worker to process TTS requests"""
930
+ while True:
931
+ try:
932
+ # Get text and callback from queue
933
+ text, callback = self.tts_queue.get()
934
+
935
+ # Generate speech
936
+ audio_path = self._generate_tts(text)
937
+
938
+ # Execute callback with result
939
+ if callback:
940
+ callback(audio_path)
941
+
942
+ # Mark task as done
943
+ self.tts_queue.task_done()
944
+ except Exception as e:
945
+ print(f"Error in TTS worker: {e}")
946
+ traceback.print_exc()
947
+
948
+ def transcribe_audio(self, audio_data, language="ml-IN"):
949
+ """Convert audio to text using speech recognition"""
950
+ if audio_data is None:
951
+ print("No audio data received")
952
+ return "No audio detected", ""
953
+
954
+ # Make sure we have audio data in the expected format
955
+ try:
956
+ if isinstance(audio_data, tuple) and len(audio_data) == 2:
957
+ # Expected format: (sample_rate, audio_samples)
958
+ sample_rate, audio_samples = audio_data
959
+ else:
960
+ print(f"Unexpected audio format: {type(audio_data)}")
961
+ return "Invalid audio format", ""
962
+
963
+ if len(audio_samples) == 0:
964
+ print("Empty audio samples")
965
+ return "No speech detected", ""
966
+
967
+ # Save the audio temporarily
968
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
969
+ temp_file.close()
970
+
971
+ # Save the audio data to the temporary file
972
+ sf.write(temp_file.name, audio_samples, sample_rate)
973
+
974
+ # Use speech recognition on the file
975
+ recognizer = sr.Recognizer()
976
+ with sr.AudioFile(temp_file.name) as source:
977
+ audio = recognizer.record(source)
978
+
979
+ text = recognizer.recognize_google(audio, language=language)
980
+ print(f"Recognized: {text}")
981
+ return text, text
982
+
983
+ except sr.UnknownValueError:
984
+ print("Speech recognition could not understand audio")
985
+ return "Could not understand audio", ""
986
+ except sr.RequestError as e:
987
+ print(f"Could not request results from Google Speech Recognition service: {e}")
988
+ return f"Speech recognition service error: {str(e)}", ""
989
+ except Exception as e:
990
+ print(f"Error processing audio: {e}")
991
+ traceback.print_exc()
992
+ return f"Error processing audio: {str(e)}", ""
993
+ finally:
994
+ # Clean up temporary file
995
+ if 'temp_file' in locals() and os.path.exists(temp_file.name):
996
+ try:
997
+ os.unlink(temp_file.name)
998
+ except Exception as e:
999
+ print(f"Error deleting temporary file: {e}")
1000
+
1001
+ def save_reference_voice(self, audio_data, reference_text):
1002
+ """Save the reference voice for future TTS generation"""
1003
+ if audio_data is None or not reference_text.strip():
1004
+ return "Error: Both reference audio and text are required"
1005
+
1006
+ self.saved_voice = audio_data
1007
+ self.saved_voice_text = reference_text.strip()
1008
+
1009
+ # Clear TTS cache when voice changes
1010
+ self.tts_cache.clear()
1011
+
1012
+ # Debug info
1013
+ sample_rate, audio_samples = audio_data
1014
+ print(f"Saved reference voice: {len(audio_samples)} samples at {sample_rate}Hz")
1015
+ print(f"Reference text: {reference_text}")
1016
+
1017
+ return f"Voice saved successfully! Reference text: {reference_text}"
1018
+
1019
+ def process_text_input(self, text):
1020
+ """Process text input from user"""
1021
+ if text and text.strip():
1022
+ return text, text
1023
+ return "No input provided", ""
1024
+
1025
+ def generate_response(self, input_text):
1026
+ """Generate AI response using GPT-3.5 Turbo"""
1027
+ if not input_text or not input_text.strip():
1028
+ return "ഇൻപുട്ട് ലഭിച്ചില്ല. വീണ്ടും ശ്രമ���ക്കുക.", None # "No input received. Please try again."
1029
+
1030
+ try:
1031
+ # Prepare conversation context from history
1032
+ messages = [{"role": "system", "content": self.system_prompt}]
1033
+
1034
+ # Add previous conversations for context
1035
+ for entry in self.conversation_history:
1036
+ role = "user" if entry["role"] == "user" else "assistant"
1037
+ messages.append({"role": role, "content": entry["content"]})
1038
+
1039
+ # Add current input
1040
+ messages.append({"role": "user", "content": input_text})
1041
+
1042
+ # Call OpenAI API
1043
+ response = openai.ChatCompletion.create(
1044
+ model="gpt-3.5-turbo",
1045
+ messages=messages,
1046
+ max_tokens=500,
1047
+ temperature=0.7
1048
+ )
1049
+
1050
+ response_text = response.choices[0].message.content.strip()
1051
+ return response_text, None
1052
+
1053
+ except Exception as e:
1054
+ error_msg = f"എറർ: GPT മോഡലിൽ നിന്ന് ഉത്തരം ലഭിക്കുന്നതിൽ പ്രശ്നമുണ്ടായി: {str(e)}"
1055
+ print(f"Error in GPT response: {e}")
1056
+ traceback.print_exc()
1057
+ return error_msg, None
1058
+
1059
+ def resample_audio(self, audio, orig_sr, target_sr):
1060
+ """Resample audio to match target sample rate only if necessary"""
1061
+ if orig_sr != target_sr:
1062
+ print(f"Resampling audio from {orig_sr}Hz to {target_sr}Hz")
1063
+ return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
1064
+ return audio
1065
+
1066
+ def _generate_tts(self, text):
1067
+ """Internal method to generate TTS without threading"""
1068
+ if not text or not text.strip():
1069
+ print("No text provided for TTS generation")
1070
+ return None
1071
+
1072
+ # Check cache first
1073
+ if text in self.tts_cache:
1074
+ print("Using cached TTS output")
1075
+ return self.tts_cache[text]
1076
+
1077
+ try:
1078
+ # Check if we have a saved voice and the TTS model
1079
+ if self.saved_voice is not None and self.tts_model is not None:
1080
+ sample_rate, audio_data = self.saved_voice
1081
+
1082
+ # Create a temporary file for the reference audio
1083
+ ref_temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
1084
+ ref_temp_file.close()
1085
+ print(f"Saving reference audio to {ref_temp_file.name}")
1086
+
1087
+ # Save the reference audio data
1088
+ sf.write(ref_temp_file.name, audio_data, sample_rate)
1089
+
1090
+ # Create a temporary file for the output audio
1091
+ output_temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
1092
+ output_temp_file.close()
1093
+
1094
+ try:
1095
+ # Generate speech using IndicF5 - simplified approach from second file
1096
+ print(f"Generating speech with IndicF5. Text: {text[:30]}...")
1097
+ start_time = time.time()
1098
+
1099
+ # Use torch.no_grad() to save memory and computation
1100
+ with torch.no_grad():
1101
+ # Run the inference - directly use the model as in the second file
1102
+ synth_audio = self.tts_model(
1103
+ text,
1104
+ ref_audio_path=ref_temp_file.name,
1105
+ ref_text=self.saved_voice_text
1106
+ )
1107
+
1108
+ end_time = time.time()
1109
+ print(f"Speech generation completed in {(end_time - start_time)} seconds")
1110
+
1111
+ # Normalize output if needed
1112
+ if synth_audio.dtype == np.int16:
1113
+ synth_audio = synth_audio.astype(np.float32) / 32768.0
1114
+
1115
+ # Resample the generated audio to match the reference audio's sample rate
1116
+ synth_audio = self.resample_audio(synth_audio, orig_sr=24000, target_sr=sample_rate)
1117
+
1118
+ # Save the synthesized audio
1119
+ print(f"Saving synthesized audio to {output_temp_file.name}")
1120
+ sf.write(output_temp_file.name, synth_audio, sample_rate)
1121
+
1122
+ # Cache the result
1123
+ self.tts_cache[text] = output_temp_file.name
1124
+
1125
+ print(f"TTS generation successful, output file: {output_temp_file.name}")
1126
+ return output_temp_file.name
1127
+ except Exception as e:
1128
+ print(f"IndicF5 TTS failed with error: {e}")
1129
+ traceback.print_exc()
1130
+ # Fall back to Google TTS
1131
+ return self.fallback_tts(text, output_temp_file.name)
1132
+ finally:
1133
+ # Clean up reference audio file
1134
+ if os.path.exists(ref_temp_file.name):
1135
+ try:
1136
+ os.unlink(ref_temp_file.name)
1137
+ except Exception as e:
1138
+ print(f"Error deleting temporary file: {e}")
1139
+ else:
1140
+ if self.saved_voice is None:
1141
+ print("No saved voice available for TTS")
1142
+ if self.tts_model is None:
1143
+ print("TTS model not initialized")
1144
+
1145
+ # No saved voice or TTS model, use fallback
1146
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
1147
+ temp_file.close()
1148
+ return self.fallback_tts(text, temp_file.name)
1149
+
1150
+ except Exception as e:
1151
+ print(f"Error in TTS processing: {e}")
1152
+ traceback.print_exc()
1153
+ return None
1154
+
1155
+ def speak_with_indicf5(self, text, callback=None):
1156
+ """Queue text for TTS generation"""
1157
+ if not text or not text.strip():
1158
+ if callback:
1159
+ callback(None)
1160
+ return None
1161
+
1162
+ # Check cache first for immediate response
1163
+ if text in self.tts_cache:
1164
+ print("Using cached TTS output")
1165
+ if callback:
1166
+ callback(self.tts_cache[text])
1167
+ return self.tts_cache[text]
1168
+
1169
+ # If no callback provided, generate synchronously
1170
+ if callback is None:
1171
+ return self._generate_tts(text)
1172
+
1173
+ # Otherwise, queue for async processing
1174
+ self.tts_queue.put((text, callback))
1175
+ return None
1176
+
1177
+ def fallback_tts(self, text, output_path):
1178
+ """Fallback to Google TTS if IndicF5 fails"""
1179
+ try:
1180
+ from gtts import gTTS
1181
+
1182
+ # Determine if text is Malayalam
1183
+ is_malayalam = any('\u0D00' <= c <= '\u0D7F' for c in text)
1184
+ lang = 'ml' if is_malayalam else 'en'
1185
+
1186
+ print(f"Using fallback Google TTS with language: {lang}")
1187
+ tts = gTTS(text=text, lang=lang, slow=False)
1188
+ tts.save(output_path)
1189
+
1190
+ # Cache the result
1191
+ self.tts_cache[text] = output_path
1192
+ print(f"Fallback TTS saved to: {output_path}")
1193
+
1194
+ return output_path
1195
+ except Exception as e:
1196
+ print(f"Fallback TTS also failed: {e}")
1197
+ traceback.print_exc()
1198
+ return None
1199
+
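+ # Hedged note: fallback_tts detects Malayalam via the Unicode block
+ # U+0D00-U+0D7F and picks gTTS language 'ml' or 'en' accordingly, e.g.
+ # engine.fallback_tts("നന്ദി", "out.wav") would synthesize Malayalam speech.
+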
1200
+ def add_message(self, role, content):
1201
+ """Add a message to the conversation history"""
1202
+ timestamp = datetime.now().strftime("%H:%M:%S")
1203
+ self.conversation_history.append({
1204
+ "role": role,
1205
+ "content": content,
1206
+ "timestamp": timestamp
1207
+ })
1208
+
1209
+ def clear_conversation(self):
1210
+ """Clear the conversation history"""
1211
+ self.conversation_history = []
1212
+
1213
+ def cleanup(self):
1214
+ """Clean up resources when shutting down"""
1215
+ print("Cleaning up resources...")
1216
+
1217
+ # Load example Malayalam voices
1218
+ def load_audio_from_url(url):
1219
+ """Load audio from a URL"""
1220
+ try:
1221
+ response = requests.get(url)
1222
+ if response.status_code == 200:
1223
+ audio_data, sample_rate = sf.read(io.BytesIO(response.content))
1224
+ return sample_rate, audio_data
1225
+ except Exception as e:
1226
+ print(f"Error loading audio from URL: {e}")
1227
+ return None, None
1228
+
1229
+ # Malayalam voice examples
1230
+ EXAMPLE_VOICES = [
1231
+ {
1232
+ "name": "Aparna Voice",
1233
+ "url": "https://raw.githubusercontent.com/Aparna0112/voicerecording-_TTS/main/Aparna%20Voice.wav",
1234
+ "transcript": "ഞാൻ ഒരു ഫോണിന്‍റെ കവർ നോക്കുകയാണ്. എനിക്ക് സ്മാർട്ട് ഫോണിന് കവർ വേണം"
1235
+ },
1236
+ {
1237
+ "name": "KC Voice",
1238
+ "url": "https://raw.githubusercontent.com/Aparna0112/voicerecording-_TTS/main/KC%20Voice.wav",
1239
+ "transcript": "ഹലോ ഇത് അപരനെ അല്ലേ ഞാൻ ജഗദീപ് ആണ് വിളിക്കുന്നത് ഇപ്പോൾ ഫ്രീയാണോ സംസാരിക്കാമോ"
1240
+ }
1241
+ ]
1242
+
1243
+ # Preload example voices
1244
+ for voice in EXAMPLE_VOICES:
1245
+ sample_rate, audio_data = load_audio_from_url(voice["url"])
1246
+ if sample_rate is not None:
1247
+ voice["audio"] = (sample_rate, audio_data)
1248
+ print(f"Loaded example voice: {voice['name']}")
1249
+ else:
1250
+ print(f"Failed to load voice: {voice['name']}")
1251
+
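+ # Quick inspection of whichever example voices actually loaded (skips failures):
+ def _demo_example_voices():
+ for v in EXAMPLE_VOICES:
+ if "audio" in v:
+ srate, samples = v["audio"]
+ print(v["name"], srate, len(samples))
+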
1252
+ def create_chatbot_interface():
1253
+ """Create a single-page chatbot interface with voice input, output, and voice selection"""
1254
+
1255
+ # Initialize conversation engine
1256
+ engine = ConversationEngine()
1257
+
1258
+ # CSS for styling the chat interface
1259
+ css = """
1260
+ .chatbot-container {
1261
+ display: flex;
1262
+ flex-direction: column;
1263
+ height: 100%;
1264
+ max-width: 800px;
1265
+ margin: 0 auto;
1266
+ }
1267
+ .chat-window {
1268
+ flex-grow: 1;
1269
+ overflow-y: auto;
1270
+ padding: 1rem;
1271
+ background: #f5f7f9;
1272
+ border-radius: 0.5rem;
1273
+ margin-bottom: 1rem;
1274
+ min-height: 400px;
1275
+ }
1276
+ .input-area {
1277
+ display: flex;
1278
+ gap: 0.5rem;
1279
+ padding: 0.5rem;
1280
+ align-items: center;
1281
+ }
1282
+ .message {
1283
+ margin-bottom: 1rem;
1284
+ padding: 0.8rem;
1285
+ border-radius: 0.5rem;
1286
+ position: relative;
1287
+ max-width: 80%;
1288
+ }
1289
+ .user-message {
1290
+ background: #e1f5fe;
1291
+ align-self: flex-end;
1292
+ margin-left: auto;
1293
+ }
1294
+ .bot-message {
1295
+ background: #f0f0f0;
1296
+ align-self: flex-start;
1297
+ }
1298
+ .timestamp {
1299
+ font-size: 0.7rem;
1300
+ color: #888;
1301
+ margin-top: 0.2rem;
1302
+ text-align: right;
1303
+ }
1304
+ .chatbot-header {
1305
+ text-align: center;
1306
+ color: #333;
1307
+ margin-bottom: 1rem;
1308
+ }
1309
+ .chat-controls {
1310
+ display: flex;
1311
+ justify-content: space-between;
1312
+ margin-bottom: 0.5rem;
1313
+ }
1314
+ .voice-selector {
1315
+ background: #f8f9fa;
1316
+ padding: 1rem;
1317
+ border-radius: 0.5rem;
1318
+ margin-bottom: 1rem;
1319
+ }
1320
+ .progress-bar {
1321
+ height: 4px;
1322
+ background-color: #e0e0e0;
1323
+ position: relative;
1324
+ margin: 10px 0;
1325
+ border-radius: 2px;
1326
+ }
1327
+ .progress-bar-fill {
1328
+ height: 100%;
1329
+ background-color: #4CAF50;
1330
+ border-radius: 2px;
1331
+ transition: width 0.3s ease-in-out;
1332
+ }
1333
+ """
1334
+
1335
+ with gr.Blocks(css=css, title="Malayalam Voice Chatbot") as interface:
1336
+ gr.Markdown("# 🤖 Malayalam Voice Chatbot with Voice Selection", elem_classes=["chatbot-header"])
1337
+
1338
+ # Create a state variable for TTS progress
1339
+ tts_progress_state = gr.State(0)
1340
+ audio_output_state = gr.State(None)
1341
+
1342
+ with gr.Row(elem_classes=["chatbot-container"]):
1343
+ with gr.Column():
1344
+ # Voice selection section - fixed to use Accordion instead of Box
1345
+ with gr.Accordion("🎤 Voice Selection", open=True):
1346
+ # Select from example voices or record your own
1347
+ voice_selector = gr.Dropdown(
1348
+ choices=[voice["name"] for voice in EXAMPLE_VOICES],
1349
+ value=EXAMPLE_VOICES[0]["name"] if EXAMPLE_VOICES else None,
1350
+ label="Select Voice Example"
1351
+ )
1352
+
1353
+ # Display selected voice info
1354
+ voice_info = gr.Textbox(
1355
+ value=EXAMPLE_VOICES[0]["transcript"] if EXAMPLE_VOICES else "",
1356
+ label="Voice Sample Transcript",
1357
+ lines=2,
1358
+ interactive=True
1359
+ )
1360
+
1361
+ # Play selected example voice
1362
+ example_audio = gr.Audio(
1363
+ value=None,
1364
+ label="Example Voice",
1365
+ interactive=False
1366
+ )
1367
+
1368
+ # Or record your own voice
1369
+ gr.Markdown("### OR Record Your Own Voice")
1370
+
1371
+ custom_voice = gr.Audio(
1372
+ sources=["microphone", "upload"],
1373
+ type="numpy",
1374
+ label="Record/Upload Your Voice"
1375
+ )
1376
+
1377
+ custom_transcript = gr.Textbox(
1378
+ value="",
1379
+ label="Your Voice Transcript (what you said in Malayalam)",
1380
+ lines=2
1381
+ )
1382
+
1383
+ # Button to save the selected/recorded voice
1384
+ save_voice_btn = gr.Button("💾 Save Voice for Chat", variant="primary")
1385
+ voice_status = gr.Textbox(label="Voice Status", value="No voice saved yet")
1386
+
1387
+ # Language selector and controls for chat
1388
+ with gr.Row(elem_classes=["chat-controls"]):
1389
+ language_selector = gr.Dropdown(
1390
+ choices=["ml-IN", "en-US", "hi-IN", "ta-IN", "te-IN", "kn-IN"],
1391
+ value="ml-IN",
1392
+ label="Speech Recognition Language"
1393
+ )
1394
+ clear_btn = gr.Button("🧹 Clear Chat", scale=0)
1395
+
1396
+ # Chat display area
1397
+ chatbot = gr.Chatbot(
1398
+ [],
1399
+ elem_id="chatbox",
1400
+ bubble_full_width=False,
1401
+ height=450,
1402
+ elem_classes=["chat-window"]
1403
+ )
1404
+
1405
+ # Progress bar for TTS generation
1406
+ with gr.Row():
1407
+ tts_progress = gr.Slider(
1408
+ minimum=0,
1409
+ maximum=100,
1410
+ value=0,
1411
+ label="TTS Progress",
1412
+ interactive=False
1413
+ )
1414
+
1415
+ # Audio output for the bot's response
1416
+ audio_output = gr.Audio(
1417
+ label="Bot's Voice Response",
1418
+ type="filepath",
1419
+ autoplay=True,
1420
+ visible=True
1421
+ )
1422
+
1423
+ # Status message for debugging
1424
+ status_msg = gr.Textbox(
1425
+ label="Status",
1426
+ value="Ready",
1427
+ interactive=False
1428
+ )
1429
+
1430
+ # Input area with separate components
1431
+ with gr.Row(elem_classes=["input-area"]):
1432
+ audio_msg = gr.Textbox(
1433
+ label="Message",
1434
+ placeholder="Type a message or record audio",
1435
+ lines=1
1436
+ )
1437
+ audio_input = gr.Audio(
1438
+ sources=["microphone"],
1439
+ type="numpy",
1440
+ label="Record",
1441
+ elem_classes=["audio-input"]
1442
+ )
1443
+ submit_btn = gr.Button("🚀 Send", variant="primary")
1444
+
1445
+ # Function to update voice example info
1446
+ def update_voice_example(voice_name):
1447
+ for voice in EXAMPLE_VOICES:
1448
+ if voice["name"] == voice_name and "audio" in voice:
1449
+ return voice["transcript"], voice["audio"]
1450
+ return "", None
1451
+
1452
+ # Function to save voice for TTS
1453
+ def save_voice_for_tts(example_name, example_audio, custom_audio, example_transcript, custom_transcript):
1454
+ try:
1455
+ # Check if we're using an example voice or custom recorded voice
1456
+ if custom_audio is not None:
1457
+ # Use custom recorded voice
1458
+ if not custom_transcript.strip():
1459
+ return "Error: Please provide a transcript for your recorded voice"
1460
+
1461
+ voice_audio = custom_audio
1462
+ transcript = custom_transcript
1463
+ source = "custom recording"
1464
+ elif example_audio is not None:
1465
+ # Use selected example voice
1466
+ voice_audio = example_audio
1467
+ transcript = example_transcript
1468
+ source = f"example: {example_name}"
1469
+ else:
1470
+ return "Error: No voice selected or recorded"
1471
+
1472
+ # Save the voice in the engine
1473
+ result = engine.save_reference_voice(voice_audio, transcript)
1474
+
1475
+ return f"Voice saved successfully! Using {source}"
1476
+ except Exception as e:
1477
+ print(f"Error saving voice: {e}")
1478
+ traceback.print_exc()
1479
+ return f"Error saving voice: {str(e)}"
1480
+
1481
+ # Function to update TTS progress
1482
+ def update_tts_progress(progress):
1483
+ return progress
1484
+
1485
+ # Audio generated callback
1486
+ def on_tts_generated(audio_path):
1487
+ print(f"TTS generation callback received path: {audio_path}")
1488
+ return audio_path, 100, "Response ready" # audio path, 100% progress, status message
1489
+
1490
+ # Function to process user input and generate response
1491
+ def process_input(audio, text_input, history, language, progress):
1492
+ try:
1493
+ # Update status
1494
+ status = "Processing input..."
1495
+
1496
+ # Reset progress bar
1497
+ progress = 0
1498
+
1499
+ # Check which input mode we're using
1500
+ if audio is not None:
1501
+ # Audio input
1502
+ transcribed_text, input_text = engine.transcribe_audio(audio, language)
1503
+ if not input_text:
1504
+ status = "Could not understand audio. Please try again."
1505
+ return history, None, status, text_input, progress
1506
+ elif text_input and text_input.strip():
1507
+ # Text input
1508
+ input_text = text_input.strip()
1509
+ transcribed_text = input_text
1510
+ else:
1511
+ # No valid input
1512
+ status = "No input detected. Please speak or type a message."
1513
+ return history, None, status, text_input, progress
1514
+
1515
+ # Add user message to conversation history
1516
+ engine.add_message("user", input_text)
1517
+
1518
+ # Update the Gradio chatbot display immediately with user message
1519
+ updated_history = history + [[transcribed_text, None]]
1520
+
1521
+ # Update status and progress
1522
+ status = "Generating response..."
1523
+ progress = 30
1524
+
1525
+ # Generate response
1526
+ response_text, _ = engine.generate_response(input_text)
1527
+
1528
+ # Add assistant response to conversation history
1529
+ engine.add_message("assistant", response_text)
1530
+
1531
+ # Update the Gradio chatbot with the assistant's response
1532
+ updated_history = history + [[transcribed_text, response_text]]
1533
+
1534
+ # Update status and progress
1535
+ status = "Generating speech..."
1536
+ progress = 60
1537
+
1538
+ # Generate speech for response synchronously (for better debugging)
1539
+ audio_path = engine._generate_tts(response_text)
1540
+
1541
+ if audio_path:
1542
+ status = f"Response ready: {audio_path}"
1543
+ progress = 100
1544
+ print(f"Audio generated successfully: {audio_path}")
1545
+ else:
1546
+ status = "Failed to generate speech"
1547
+
1548
+ # Clear the text input
1549
+ return updated_history, audio_path, status, "", progress
1550
+
1551
+ except Exception as e:
1552
+ # Catch any unexpected errors
1553
+ error_message = f"Error: {str(e)}"
1554
+ print(error_message)
1555
+ traceback.print_exc()
1556
+ return history, None, error_message, text_input, progress
1557
+
1558
+ # Function to clear chat history
1559
+ def clear_chat():
1560
+ engine.clear_conversation()
1561
+ return [], None, "Chat history cleared", "", 0
1562
+
1563
+ # Connect event handlers
1564
+
1565
+ # Voice selection handlers
1566
+ voice_selector.change(
1567
+ update_voice_example,
1568
+ inputs=[voice_selector],
1569
+ outputs=[voice_info, example_audio]
1570
+ )
1571
+
1572
+ # Save voice button handler
1573
+ save_voice_btn.click(
1574
+ save_voice_for_tts,
1575
+ inputs=[voice_selector, example_audio, custom_voice, voice_info, custom_transcript],
1576
+ outputs=[voice_status]
1577
+ )
1578
+
1579
+ # Chat handlers
1580
+ submit_btn.click(
1581
+ process_input,
1582
+ inputs=[audio_input, audio_msg, chatbot, language_selector, tts_progress_state],
1583
+ outputs=[chatbot, audio_output, status_msg, audio_msg, tts_progress]
1584
+ )
1585
+
1586
+ # Allow sending by pressing Enter key in the text input
1587
+ audio_msg.submit(
1588
+ process_input,
1589
+ inputs=[audio_input, audio_msg, chatbot, language_selector, tts_progress_state],
1590
+ outputs=[chatbot, audio_output, status_msg, audio_msg, tts_progress]
1591
+ )
1592
+
1593
+ # Clear button handler
1594
+ clear_btn.click(
1595
+ clear_chat,
1596
+ inputs=[],
1597
+ outputs=[chatbot, audio_output, status_msg, audio_msg, tts_progress]
1598
+ )
1599
+
1600
+ # Setup cleanup on exit
1601
+ def exit_handler():
1602
+ engine.cleanup()
1603
+
1604
+ import atexit
1605
+ atexit.register(exit_handler)
1606
+
1607
+ # Enable queueing for better responsiveness
1608
+ interface.queue()
1609
+
1610
+ return interface
1611
+
1612
+ # Start the interface
1613
+ if __name__ == "__main__":
1614
+ print("Starting Malayalam Voice Chatbot with IndicF5 Voice Selection...")
1615
+ interface = create_chatbot_interface()
1616
+ interface.launch(debug=True) # Enable debug mode to see errors in the console
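+
+ # Hedged dependency note: running this app assumes gradio, torch, transformers,
+ # huggingface_hub, librosa, soundfile, SpeechRecognition (plus PyAudio for
+ # microphone capture), gTTS, requests, numpy, and the legacy openai (<1.0) SDK
+ # are installed.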