Devakumar868 committed on
Commit daf7e26 · verified · 1 Parent(s): 2bc37a4

Update app.py

Files changed (1):
  app.py +340 -160
app.py CHANGED
@@ -1,187 +1,367 @@
- import gradio as gr
  import torch
  import numpy as np
- from dia.model import Dia
- import warnings
-
- # Suppress warnings for cleaner output
- warnings.filterwarnings("ignore", category=FutureWarning)
- warnings.filterwarnings("ignore", category=UserWarning)
-
- # Global model variable
- model = None
-
- def load_model_once():
-     """Load the Dia model once and cache it globally"""
-     global model
-     if model is None:
-         print("Loading Dia model... This may take a few minutes.")
          try:
-             # Load model without trying to move it manually to GPU;
-             # the Dia model handles GPU placement internally
-             model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32")
-
-             print("Model loaded successfully!")
-             if torch.cuda.is_available():
-                 print(f"CUDA is available: {torch.cuda.get_device_name()}")
-             else:
-                 print("CUDA is not available, using CPU")
-
          except Exception as e:
-             print(f"Error loading model: {e}")
-             raise e
-
-     return model
-
- def generate_audio(text, seed=42):
-     """Generate audio from text input with error handling"""
-     try:
-         # Clear GPU cache before generation
-         if torch.cuda.is_available():
              torch.cuda.empty_cache()
-
-         current_model = load_model_once()
-
-         # Validate input
-         if not text or not text.strip():
-             return None, "❌ Please enter some text"
-
-         # Clean and format text
-         text = text.strip()
-         if not text.startswith('[S1]') and not text.startswith('[S2]'):
-             text = '[S1] ' + text
-
-         # Set seed for reproducibility
-         if seed:
-             torch.manual_seed(int(seed))
-             if torch.cuda.is_available():
-                 torch.cuda.manual_seed(int(seed))
-
-         print(f"Generating speech for: {text[:100]}...")
-
-         # Generate audio - disable torch compile for T4 stability
-         with torch.no_grad():
-             audio_output = current_model.generate(
-                 text,
-                 use_torch_compile=False,  # Disabled for T4 compatibility
-                 verbose=False
-             )
-
-         # Ensure audio_output is a numpy array
-         if isinstance(audio_output, torch.Tensor):
-             audio_output = audio_output.cpu().numpy()
-
-         # Normalize audio to prevent clipping
-         if len(audio_output) > 0:
-             max_val = np.max(np.abs(audio_output))
-             if max_val > 1.0:
-                 audio_output = audio_output / max_val * 0.95
-
-         print("✅ Audio generated successfully!")
-         return (44100, audio_output), "✅ Audio generated successfully!"
-
-     except torch.cuda.OutOfMemoryError:
-         # Handle GPU memory issues
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-         error_msg = "❌ GPU memory error. Try shorter text or restart the space."
-         print(error_msg)
-         return None, error_msg
-
-     except Exception as e:
-         error_msg = f"❌ Error: {str(e)}"
-         print(error_msg)
-         return None, error_msg
-
- # Create the Gradio interface
- demo = gr.Blocks(title="Dia TTS Demo")
-
- with demo:
-     gr.HTML("""
-     <div style="text-align: center; padding: 20px;">
-         <h1>🎙️ Dia TTS - Ultra-Realistic Text-to-Speech</h1>
-         <p style="font-size: 18px; color: #666;">
-             Generate multi-speaker, emotion-aware dialogue using the Dia 1.6B model
-         </p>
-     </div>
-     """)
-
-     with gr.Row():
-         with gr.Column():
-             text_input = gr.Textbox(
-                 label="📝 Text to Speech",
-                 placeholder="[S1] Hello there! How are you today? [S2] I'm doing great, thanks for asking! (laughs)",
-                 lines=6,
-                 value="[S1] Welcome to the Dia TTS demo! [S2] This is amazing technology!",
-                 info="Use [S1] and [S2] for different speakers. Add emotions like (laughs), (sighs), etc."
-             )
-
-             seed_input = gr.Number(
-                 label="🎲 Random Seed",
-                 value=42,
-                 precision=0,
-                 info="Same seed = consistent voices"
-             )
-
-             generate_btn = gr.Button("🎵 Generate Speech", variant="primary")
-
-         with gr.Column():
-             audio_output = gr.Audio(
-                 label="🔊 Generated Audio",
-                 type="numpy"
-             )
-
-             status_text = gr.Textbox(
-                 label="📊 Status",
-                 interactive=False,
-                 lines=2
-             )
-
-     # Connect the button to the function
-     generate_btn.click(
-         fn=generate_audio,
-         inputs=[text_input, seed_input],
-         outputs=[audio_output, status_text]
-     )
-
-     # Add example buttons
-     with gr.Row():
-         example_btn1 = gr.Button("📻 Podcast", size="sm")
-         example_btn2 = gr.Button("😄 Chat", size="sm")
-         example_btn3 = gr.Button("🎭 Drama", size="sm")
-
-     # Example button functions
-     example_btn1.click(
-         lambda: "[S1] Welcome to our podcast! [S2] Thanks for having me on the show!",
-         outputs=text_input
-     )
-
-     example_btn2.click(
-         lambda: "[S1] Did you see the game? [S2] Yes! (laughs) It was incredible!",
-         outputs=text_input
-     )
-
-     example_btn3.click(
-         lambda: "[S1] I can't believe you're leaving. (sighs) [S2] I know, it's hard. (sad)",
-         outputs=text_input
-     )
-
-     # Usage instructions
-     gr.HTML("""
-     <div style="margin-top: 20px; padding: 15px; background: #f0f8ff; border-radius: 8px;">
-         <h3>💡 Usage Tips:</h3>
-         <ul>
-             <li><strong>Speaker Tags:</strong> Use [S1] and [S2] to switch between speakers</li>
-             <li><strong>Emotions:</strong> Add (laughs), (sighs), (excited), (whispers), (sad), etc.</li>
-             <li><strong>Length:</strong> Keep text to a moderate length (5-20 seconds of speech works best)</li>
-             <li><strong>Seeds:</strong> Use the same seed number for consistent voice characteristics</li>
-         </ul>
-
-         <p><strong>Supported Emotions:</strong> (laughs), (sighs), (gasps), (excited), (sad), (angry),
-         (surprised), (whispers), (shouts), (coughs), (clears throat), (sniffs), (chuckles), (groans)</p>
-     </div>
-     """)
-
- # Launch with basic configuration
  if __name__ == "__main__":
-     demo.launch()
+ import os
+ import gc
+ import tempfile
  import torch
  import numpy as np
+ import soundfile as sf
+ import gradio as gr
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     BitsAndBytesConfig,
+     pipeline
+ )
+ from TTS.api import TTS
+ import nemo.collections.asr as nemo_asr
+
+ # Configuration
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ SAMPLE_RATE = 22050
+ MAX_LENGTH = 512
+ TEMPERATURE = 0.7
+ SEED = 42
+
+ # Set seeds for reproducibility
+ torch.manual_seed(SEED)
+ np.random.seed(SEED)
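+
+ # Note (editorial assumption, not verified here): Coqui XTTS-v2 typically
+ # renders audio at 24 kHz, so if the loaded TTS model's native rate differs
+ # from SAMPLE_RATE = 22050 the output will be slightly pitch-shifted; reading
+ # the synthesizer's reported output sample rate at runtime would avoid this.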
+
+ class ConversationalAI:
+     def __init__(self):
+         print("🔄 Initializing Conversational AI...")
+         self.setup_models()
+         print("✅ All models loaded successfully!")
+
+     def setup_models(self):
+         """Initialize all models with T4 GPU optimization"""
+
+         # 1. ASR model - Parakeet for high-accuracy speech recognition
+         print("📢 Loading ASR model...")
          try:
+             self.asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
+                 model_name="nvidia/parakeet-tdt-0.6b-v2"
+             ).to(DEVICE)
+             self.asr_model.eval()
+             print("✅ ASR model loaded")
          except Exception as e:
+             print(f"⚠️ ASR fallback: {e}")
+             # Fall back to Whisper if Parakeet fails
+             self.asr_pipeline = pipeline(
+                 "automatic-speech-recognition",
+                 model="openai/whisper-base.en",
+                 device=0 if DEVICE == "cuda" else -1
+             )
+
+         # 2. LLM - 4-bit quantized DialoGPT for T4 GPU compatibility
+         print("🧠 Loading LLM model...")
+         quantization_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=torch.float16,
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_quant_type="nf4"
+         )
+
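+         # Rough sizing note: NF4 stores weights in 4 bits with float16 compute,
+         # and double quantization also compresses the quantization constants;
+         # DialoGPT-medium (~355M parameters) drops from roughly 0.7 GB in fp16
+         # to roughly 0.2 GB in 4-bit, well within a T4's 16 GB.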
+         model_name = "microsoft/DialoGPT-medium"  # Optimized for conversation
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         self.llm_model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             quantization_config=quantization_config,
+             device_map="auto",
+             torch_dtype=torch.float16,
+             low_cpu_mem_usage=True
+         )
+
+         print("✅ LLM model loaded")
+
+         # 3. TTS model - Coqui TTS for female voice consistency
+         print("🗣️ Loading TTS model...")
+         try:
+             # Using XTTS-v2 for a high-quality female voice
+             self.tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(DEVICE)
+
+             # Create consistent female voice embedding
+             self.female_voice_path = self.create_female_reference()
+             print("✅ TTS model loaded with female voice")
+         except Exception as e:
+             print(f"⚠️ TTS fallback: {e}")
+             # Fall back to a simpler TTS model; no voice reference in this case
+             self.tts = TTS("tts_models/en/ljspeech/tacotron2-DDC").to(DEVICE)
+             self.female_voice_path = None
+
+         # Memory optimization
+         if DEVICE == "cuda":
              torch.cuda.empty_cache()
+
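+         # With device_map="auto", accelerate places the quantized LLM on the
+         # GPU itself, which is why only the ASR/TTS models are moved explicitly
+         # with .to(DEVICE); input tensors still need .to(DEVICE) at generation
+         # time (see generate_response).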
+     def create_female_reference(self):
+         """Create a consistent female voice reference for TTS"""
+         # Generate a short reference audio with consistent female characteristics
+         reference_text = "Hello, I am your AI assistant with a consistent female voice."
+
+         # Create temporary reference file
+         temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
+
+         try:
+             # Use a built-in female speaker if available
+             wav = self.tts.tts(
+                 text=reference_text,
+                 language="en",
+                 split_sentences=True
+             )
+
+             # Save reference audio
+             sf.write(temp_file.name, wav, SAMPLE_RATE)
+             return temp_file.name
+         except Exception:
+             return None
+
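+     # Note: multi-speaker XTTS-v2 generally requires a speaker= or speaker_wav=
+     # argument, so the bootstrap call above may raise and return None, in which
+     # case synthesize_speech() falls back to default synthesis; bundling a small
+     # reference WAV would make the female voice fully deterministic.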
+     def transcribe_audio(self, audio_data):
+         """Convert speech to text using ASR"""
+         try:
+             if hasattr(self, 'asr_model'):
+                 # Save audio temporarily for NeMo ASR
+                 temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
+                 sf.write(temp_file.name, audio_data[1], audio_data[0])
+
+                 # Transcribe
+                 transcription = self.asr_model.transcribe([temp_file.name])[0]
+                 os.unlink(temp_file.name)
+
+                 return transcription.text if hasattr(transcription, 'text') else transcription
+             else:
+                 # Use Whisper pipeline
+                 return self.asr_pipeline({"sampling_rate": audio_data[0], "raw": audio_data[1]})["text"]
+
+         except Exception as e:
+             print(f"ASR Error: {e}")
+             return "Sorry, I couldn't understand the audio."
+
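+     # Gradio's type="numpy" microphone input arrives as (sample_rate, int16
+     # array); Parakeet models are trained on 16 kHz mono audio and the Whisper
+     # pipeline expects float32 samples in [-1, 1], so resampling the mic's
+     # native rate (often 44.1 or 48 kHz) and converting dtype before
+     # transcription may improve accuracy, depending on library versions.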
+     def generate_response(self, user_input, chat_history):
+         """Generate conversational response using LLM"""
+         try:
+             # Prepare conversation context
+             context = ""
+             for turn in chat_history[-3:]:  # Last 3 turns for context
+                 context += f"Human: {turn[0]}\nAssistant: {turn[1]}\n"
+
+             context += f"Human: {user_input}\nAssistant:"
+
+             # Tokenize and generate
+             inputs = self.tokenizer.encode(
+                 context, return_tensors="pt", max_length=MAX_LENGTH, truncation=True
+             ).to(DEVICE)
+
+             with torch.no_grad():
+                 outputs = self.llm_model.generate(
+                     inputs,
+                     max_length=inputs.shape[1] + 100,
+                     temperature=TEMPERATURE,
+                     do_sample=True,
+                     pad_token_id=self.tokenizer.eos_token_id,
+                     no_repeat_ngram_size=2,
+                     top_p=0.9
+                 )
+
+             response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
+             response = response.split("Human:")[0].strip()
+
+             return response if response else "I understand. Please tell me more."
+
+         except Exception as e:
+             print(f"LLM Error: {e}")
+             return "I'm having trouble processing that. Could you please rephrase?"
+
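+     # Note: DialoGPT was trained on Reddit turns separated by its EOS token
+     # rather than "Human:"/"Assistant:" labels, so the template above is a
+     # prompting convention; splitting on "Human:" trims any follow-on turn the
+     # model invents.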
+     def synthesize_speech(self, text):
+         """Convert text to speech with consistent female voice"""
+         try:
+             if self.female_voice_path and hasattr(self.tts, 'tts'):
+                 # Use voice cloning for consistency
+                 wav = self.tts.tts(
+                     text=text,
+                     speaker_wav=self.female_voice_path,
+                     language="en",
+                     split_sentences=True
+                 )
+             else:
+                 # Fallback to default synthesis
+                 wav = self.tts.tts(text=text)
+
+             # Ensure proper format
+             if isinstance(wav, list):
+                 wav = np.array(wav, dtype=np.float32)
+
+             # Normalize audio
+             wav = wav / np.max(np.abs(wav)) if np.max(np.abs(wav)) > 0 else wav
+
+             return (SAMPLE_RATE, (wav * 32767).astype(np.int16))
+
+         except Exception as e:
+             print(f"TTS Error: {e}")
+             # Return silence as fallback
+             return (SAMPLE_RATE, np.zeros(SAMPLE_RATE, dtype=np.int16))
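+
+     # Peak normalization bounds the waveform to [-1, 1], so scaling by 32767
+     # maps it onto the full signed 16-bit range that gr.Audio expects for
+     # int16 audio without clipping.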
+
+     def process_conversation(self, audio_input, chat_history):
+         """Main pipeline: Speech -> Text -> LLM -> Speech"""
+         if audio_input is None:
+             return chat_history, None, ""
+
+         try:
+             # Step 1: Speech to Text
+             user_text = self.transcribe_audio(audio_input)
+             if not user_text.strip():
+                 return chat_history, None, "No speech detected."
+
+             # Step 2: Generate Response
+             ai_response = self.generate_response(user_text, chat_history)
+
+             # Step 3: Text to Speech
+             audio_response = self.synthesize_speech(ai_response)
+
+             # Update chat history
+             chat_history.append([user_text, ai_response])
+
+             # Memory cleanup
+             if DEVICE == "cuda":
+                 torch.cuda.empty_cache()
+             gc.collect()
+
+             return chat_history, audio_response, f"You said: {user_text}"
+
+         except Exception as e:
+             error_msg = f"Error processing conversation: {e}"
+             print(error_msg)
+             return chat_history, None, error_msg
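+
+     # chat_history holds [user, assistant] pairs, the tuple format gr.Chatbot
+     # uses by default in Gradio 4.x, so the list is passed straight back to the
+     # Chatbot component by the UI event handlers.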
+
+ # Initialize the AI system
+ print("🚀 Starting Conversational AI initialization...")
+ ai_system = ConversationalAI()
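+ # Loading at import time means all three models are ready before the first
+ # request; the trade-off is a startup that can take several minutes on a T4.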
+
+ # Gradio interface
+ def create_interface():
+     """Create the Gradio interface for the conversational AI"""
+
+     with gr.Blocks(
+         title="Advanced Conversational AI",
+         theme=gr.themes.Soft(),
+         css="""
+         .main-header { text-align: center; color: #2563eb; margin-bottom: 2rem; }
+         .chat-container { max-height: 500px; overflow-y: auto; }
+         .status-box { background: #f0f9ff; padding: 1rem; border-radius: 0.5rem; }
+         """
+     ) as demo:
+
+         gr.HTML("""
+         <div class="main-header">
+             <h1>🤖 Advanced Conversational AI</h1>
+             <p>Speak naturally and get intelligent responses in a consistent female voice</p>
+         </div>
+         """)
+
+         with gr.Row():
+             with gr.Column(scale=2):
+                 # Chat History
+                 chatbot = gr.Chatbot(
+                     label="Conversation History",
+                     elem_classes=["chat-container"],
+                     height=400,
+                     show_copy_button=True
+                 )
+
+                 # Audio Input
+                 audio_input = gr.Audio(
+                     label="🎤 Speak to AI",
+                     sources=["microphone"],
+                     type="numpy",
+                     format="wav"
+                 )
+
+                 # Control Buttons
+                 with gr.Row():
+                     submit_btn = gr.Button("💬 Process Speech", variant="primary", scale=2)
+                     clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary", scale=1)
+
+             with gr.Column(scale=1):
+                 # AI Response Audio
+                 audio_output = gr.Audio(
+                     label="🔊 AI Response",
+                     type="numpy",
+                     autoplay=True
+                 )
+
+                 # Status Display
+                 status_display = gr.Textbox(
+                     label="📊 Status",
+                     lines=3,
+                     elem_classes=["status-box"],
+                     interactive=False
+                 )
+
+                 # System Information
+                 gr.HTML(f"""
+                 <div class="status-box">
+                     <h3>🔧 System Info</h3>
+                     <p><strong>Device:</strong> {DEVICE.upper()}</p>
+                     <p><strong>Models:</strong> Parakeet ASR + DialoGPT + XTTS</p>
+                     <p><strong>Voice:</strong> Consistent Female</p>
+                     <p><strong>Memory:</strong> 4-bit Quantized</p>
+                 </div>
+                 """)
+
+         # Event Handlers
+         def process_audio(audio, history):
+             return ai_system.process_conversation(audio, history)
+
+         def clear_conversation():
+             if DEVICE == "cuda":
+                 torch.cuda.empty_cache()
+             return [], None, "Conversation cleared."
+
+         # Button Events
+         submit_btn.click(
+             fn=process_audio,
+             inputs=[audio_input, chatbot],
+             outputs=[chatbot, audio_output, status_display],
+             show_progress=True
+         )
+
+         clear_btn.click(
+             fn=clear_conversation,
+             outputs=[chatbot, audio_output, status_display]
+         )
+
+         # Auto-process when audio is recorded
+         audio_input.change(
+             fn=process_audio,
+             inputs=[audio_input, chatbot],
+             outputs=[chatbot, audio_output, status_display]
+         )
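+
+         # Note: both the button click above and this change event invoke
+         # process_audio, so a single recording can be processed twice; removing
+         # one of the two triggers would avoid duplicate turns.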
+
+         # Example Usage
+         gr.HTML("""
+         <div style="margin-top: 2rem; padding: 1rem; background: #fef3c7; border-radius: 0.5rem;">
+             <h3>💡 How to Use:</h3>
+             <ol>
+                 <li>Click the microphone button and speak clearly</li>
+                 <li>Wait for the AI to process your speech</li>
+                 <li>Listen to the AI's response in the consistent female voice</li>
+                 <li>Continue the conversation naturally</li>
+             </ol>
+         </div>
+         """)
+
+     return demo
+
+ # Launch the application
  if __name__ == "__main__":
+     print("🌟 Creating Gradio interface...")
+     demo = create_interface()
+
+     print("🚀 Launching Conversational AI...")
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=True,
+         show_error=True,
+         debug=False
+     )
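+
+     # On Hugging Face Spaces the host/port above are managed by the platform
+     # and share=True is typically ignored; these settings mainly matter for
+     # local runs.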