jsbeaudry committed on
Commit 35e66cc · verified · 1 Parent(s): 97306da

Create handler.py

Files changed (1)
  1. handler.py +422 -0
handler.py ADDED
@@ -0,0 +1,422 @@
+ import asyncio
+ import torch
+ import librosa
+ import numpy as np
+ import soundfile as sf
+ from transformers import (
+     AutoProcessor, AutoModelForSpeechSeq2Seq,
+     AutoModelForCausalLM, AutoTokenizer,
+     pipeline, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
+ )
+ from datasets import load_dataset
+ import logging
+ from typing import Optional, Dict, Any
+ import time
+ from pathlib import Path
+
+ from kokoro import KPipeline
+ from IPython.display import display, Audio
+
+
+ import gradio as gr
+ import os
+
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class AsyncAIConversation:
+     def __init__(self):
+         self.stt_processor = None
+         self.stt_model = None
+         self.llm_tokenizer = None
+         self.llm_model = None
+         self.tts_synthesizer = None
+         self.speaker_embedding = None
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         logger.info(f"Using device: {self.device}")
+
+     async def initialize_models(self):
+         """Initialize all models asynchronously"""
+         logger.info("Initializing models...")
+
+         # Initialize STT model
+         await self._init_stt_model()
+
+         # Initialize LLM model
+         await self._init_llm_model()
+
+         # Initialize TTS model
+         await self._init_tts_model()
+
+         logger.info("All models initialized successfully!")
+
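+     # Minimal usage sketch (placeholder path; the commented-out demo_conversation()
+     # further down shows the same flow in full):
+     #
+     #   conv = AsyncAIConversation()
+     #   await conv.initialize_models()
+     #   result = await conv.process_conversation("recording.wav")
+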
+     async def _init_stt_model(self):
+         """Initialize Speech-to-Text model"""
+         logger.info("Loading STT model...")
+         try:
+             stt_model_id = "unsloth/whisper-small"
+             # Alternative checkpoint: unsloth/whisper-large-v3-turbo
+             self.stt_processor = AutoProcessor.from_pretrained(stt_model_id)
+             self.stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(stt_model_id)
+             self.stt_model.to(self.device)
+             logger.info("STT model loaded successfully")
+         except Exception as e:
+             logger.error(f"Error loading STT model: {e}")
+             raise
+
+     async def _init_llm_model(self):
+         """Initialize Large Language Model"""
+         logger.info("Loading LLM model...")
+         try:
+             model_name = "unsloth/Qwen3-0.6B"
+             # Alternative checkpoint: unsloth/Qwen3-0.6B-unsloth-bnb-4bit
+             self.llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
+             self.llm_model = AutoModelForCausalLM.from_pretrained(
+                 model_name,
+                 torch_dtype="auto",
+                 device_map="auto"
+             )
+             logger.info("LLM model loaded successfully")
+         except Exception as e:
+             logger.error(f"Error loading LLM model: {e}")
+             raise
+
+     async def _init_tts_model(self):
+         """Initialize Text-to-Speech model"""
+         logger.info("Loading TTS model...")
+         try:
+             # Initialize Kokoro TTS pipeline
+             self.tts_synthesizer = KPipeline(lang_code='a')
+             logger.info("TTS model loaded successfully")
+         except Exception as e:
+             logger.error(f"Error loading TTS model: {e}")
+             raise
+
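+     # Note (based on Kokoro's documented conventions): lang_code='a' selects
+     # American English, and 'af_heart' used in text_to_speech() below is one of its
+     # American-English voice presets. Kokoro renders audio at 24 kHz, which is why
+     # text_to_speech() writes files with samplerate=24000.
+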
+     async def speech_to_text(self, audio_file_path: str) -> str:
+         """Convert speech to text asynchronously"""
+         logger.info(f"Processing audio file: {audio_file_path}")
+
+         try:
+             # Load audio in a separate thread to avoid blocking
+             def load_audio():
+                 return librosa.load(audio_file_path, sr=16000)
+
+             loop = asyncio.get_event_loop()
+             speech_array, sampling_rate = await loop.run_in_executor(None, load_audio)
+
+             # Convert to tensor
+             speech_array_pt = torch.from_numpy(speech_array).unsqueeze(0).to(self.device)
+
+             # Process input features
+             input_features = self.stt_processor(
+                 speech_array,
+                 sampling_rate=sampling_rate,
+                 return_tensors="pt"
+             ).input_features.to(self.device)
+
+             # Generate predictions
+             with torch.no_grad():
+                 predicted_ids = self.stt_model.generate(input_features)
+
+             # Decode predictions
+             transcription = self.stt_processor.batch_decode(predicted_ids, skip_special_tokens=True)
+
+             result = transcription[0] if transcription else ""
+             logger.info(f"STT result: {result}")
+             return result
+
+         except Exception as e:
+             logger.error(f"Error in speech_to_text: {e}")
+             return ""
+
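+     # Note: Whisper checkpoints expect 16 kHz mono input (hence sr=16000 above), and
+     # the processor pads/truncates to a 30-second window, so this simple generate()
+     # call only transcribes roughly the first 30 seconds of a longer recording.
+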
+     async def process_with_llm(self, text: str, system_prompt: Optional[str] = None) -> Dict[str, str]:
+         """Process text with LLM and return both thinking and content"""
+         logger.info(f"Processing text with LLM: {text[:50]}...")
+
+         try:
+             # Prepare messages
+             messages = [
+                 {"role": "user", "content": text}
+             ]
+
+             if system_prompt:
+                 messages.insert(0, {"role": "system", "content": system_prompt})
+
+             # Apply chat template
+             formatted_text = self.llm_tokenizer.apply_chat_template(
+                 messages,
+                 tokenize=False,
+                 add_generation_prompt=True,
+                 enable_thinking=False
+             )
+
+             # Tokenize
+             model_inputs = self.llm_tokenizer([formatted_text], return_tensors="pt").to(self.llm_model.device)
+
+             # Generate response
+             with torch.no_grad():
+                 generated_ids = self.llm_model.generate(
+                     **model_inputs,
+                     max_new_tokens=512,
+                     temperature=0.7,
+                     do_sample=True,
+                     pad_token_id=self.llm_tokenizer.eos_token_id
+                 )
+
+             # Extract new tokens
+             output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
+
+             # Parse thinking content
+             try:
+                 # Find the end of thinking token (</think>)
+                 index = len(output_ids) - output_ids[::-1].index(151668)
+             except ValueError:
+                 index = 0
+
+             thinking_content = self.llm_tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
+             content = self.llm_tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
+
+             result = {
+                 "thinking": thinking_content,
+                 "response": content
+             }
+
+             logger.info(f"LLM response generated: {content[:50]}...")
+             return result
+
+         except Exception as e:
+             logger.error(f"Error in process_with_llm: {e}")
+             return {"thinking": "", "response": "Sorry, I encountered an error processing your request."}
+
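+     # Note: 151668 is the token id of Qwen3's </think> marker (this mirrors the
+     # parsing shown in the Qwen3 model card). With enable_thinking=False the model
+     # emits no thinking block, so index falls back to 0 and "thinking" stays empty,
+     # leaving the full decode in "response".
+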
+     async def text_to_speech(self, text: str, output_path: str = "response.wav") -> str:
+         """Convert text to speech asynchronously"""
+         logger.info(f"Converting text to speech: {text[:50]}...")
+
+         try:
+             # Generate speech in a separate thread to avoid blocking
+             def generate_speech():
+                 # Generate audio using Kokoro TTS
+                 generator = self.tts_synthesizer(text, voice='af_heart')
+
+                 # Get the first generated audio chunk
+                 for i, (gs, ps, audio) in enumerate(generator):
+                     if i == 0:  # Use the first chunk
+                         return audio
+                 return None
+
+             loop = asyncio.get_event_loop()
+             audio_data = await loop.run_in_executor(None, generate_speech)
+
+             if audio_data is None:
+                 raise ValueError("Failed to generate audio")
+
+             # Save audio file with Kokoro's default sample rate (24000 Hz)
+             sf.write(output_path, audio_data, samplerate=24000)
+
+             logger.info(f"Audio saved to: {output_path}")
+             return output_path
+
+         except Exception as e:
+             logger.error(f"Error in text_to_speech: {e}")
+             return ""
+
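+     # Note: Kokoro yields one audio chunk per segment of the input text, and the
+     # helper above keeps only the first chunk, so long responses are cut short.
+     # A sketch of a full-length variant (assuming each chunk is array-like):
+     #
+     #   chunks = [audio for _, _, audio in self.tts_synthesizer(text, voice='af_heart')]
+     #   audio_data = np.concatenate(chunks) if chunks else None
+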
+     async def process_conversation(self, audio_file_path: str, system_prompt: Optional[str] = None) -> Dict[str, Any]:
+         """Complete conversation pipeline: STT -> LLM -> TTS"""
+         start_time = time.time()
+         logger.info("Starting conversation processing...")
+
+         try:
+             # Step 1: Speech to Text
+             stt_start = time.time()
+             transcribed_text = await self.speech_to_text(audio_file_path)
+             stt_time = time.time() - stt_start
+
+             if not transcribed_text:
+                 return {"error": "Failed to transcribe audio"}
+
+             # Step 2: Process with LLM
+             llm_start = time.time()
+             llm_result = await self.process_with_llm(transcribed_text, system_prompt)
+             llm_time = time.time() - llm_start
+
+             # Step 3: Text to Speech
+             tts_start = time.time()
+             audio_output_path = await self.text_to_speech(llm_result["response"])
+             tts_time = time.time() - tts_start
+
+             total_time = time.time() - start_time
+
+             result = {
+                 "input_audio": audio_file_path,
+                 "transcribed_text": transcribed_text,
+                 "thinking": llm_result["thinking"],
+                 "response_text": llm_result["response"],
+                 "output_audio": audio_output_path,
+                 "processing_times": {
+                     "stt": stt_time,
+                     "llm": llm_time,
+                     "tts": tts_time,
+                     "total": total_time
+                 }
+             }
+
+             logger.info(f"Conversation processed successfully in {total_time:.2f} seconds")
+             return result
+
+         except Exception as e:
+             logger.error(f"Error in process_conversation: {e}")
+             return {"error": str(e)}
+
+     async def batch_process(self, audio_files: list, system_prompt: Optional[str] = None) -> list:
+         """Process multiple audio files concurrently"""
+         logger.info(f"Processing {len(audio_files)} audio files...")
+
+         # Create tasks for concurrent processing
+         tasks = [
+             self.process_conversation(audio_file, system_prompt)
+             for audio_file in audio_files
+         ]
+
+         # Process all files concurrently
+         results = await asyncio.gather(*tasks, return_exceptions=True)
+
+         logger.info("Batch processing completed")
+         return results
+
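+     # Note: with return_exceptions=True, asyncio.gather() places raised exceptions
+     # into the results list instead of propagating them, so callers should check
+     # each entry before using it.
+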
+ # Usage example and demo functions
+ # async def demo_conversation():
+ #     """Demonstration of the conversation system"""
+ #     # Initialize the conversation system
+ #     ai_conversation = AsyncAIConversation()
+
+ #     # Initialize all models
+ #     await ai_conversation.initialize_models()
+
+ #     # Example usage
+ #     audio_file = "/content/Recording 2.wav"  # Replace with your audio file path
+ #     system_prompt = "You are a helpful assistant. Please provide clear and concise responses."
+
+ #     # Process the conversation
+ #     result = await ai_conversation.process_conversation(audio_file, system_prompt)
+
+ #     if "error" in result:
+ #         print(f"Error: {result['error']}")
+ #     else:
+ #         print(f"Transcribed: {result['transcribed_text']}")
+ #         print(f"Thinking: {result['thinking']}")
+ #         print(f"Response: {result['response_text']}")
+ #         print(f"Audio saved to: {result['output_audio']}")
+ #         print(f"Processing times: {result['processing_times']}")
+
+ # async def demo_batch_processing():
+ #     """Demonstration of batch processing"""
+ #     ai_conversation = AsyncAIConversation()
+ #     await ai_conversation.initialize_models()
+
+ #     # Example batch processing
+ #     audio_files = [
+ #         "/content/Recording 1.wav",
+ #         "/content/Recording 2.wav",
+ #         "/content/Recording 3.wav"
+ #     ]
+
+ #     results = await ai_conversation.batch_process(audio_files)
+
+ #     for i, result in enumerate(results):
+ #         print(f"File {i+1}: {result}")
+
+ # Additional utility function for testing Kokoro TTS standalone
+ # async def test_kokoro_tts():
+ #     """Test Kokoro TTS functionality standalone"""
+ #     try:
+ #         tts_synthesizer = KPipeline(lang_code='a')
+
+ #         test_text = "Hello, this is a test of the Kokoro text-to-speech system."
+
+ #         # Generate audio
+ #         generator = tts_synthesizer(test_text, voice='af_heart')
+
+ #         for i, (gs, ps, audio) in enumerate(generator):
+ #             output_path = f"kokoro_test_{i}.wav"
+ #             sf.write(output_path, audio, 24000)
+ #             print(f"Test audio {i} saved to: {output_path}")
+
+ #             # Only process first chunk for testing
+ #             if i == 0:
+ #                 break
+
+ #     except Exception as e:
+ #         print(f"Error testing Kokoro TTS: {e}")
+
+
+ # Create the async function wrapper for Gradio
+ async def process_audio_gradio(audio_file, system_prompt_input):
+     """Processes audio file and system prompt for Gradio interface."""
+     if audio_file is None:
+         return "Please upload an audio file.", "", "", None
+
+     # Gradio provides the file path
+     audio_path = audio_file
+
+     # Process the conversation using the initialized ai_conversation instance
+     try:
+         result = await ai_conversation.process_conversation(
+             audio_file_path=audio_path,
+             system_prompt=system_prompt_input
+         )
+
+         if "error" in result:
+             return f"Error: {result['error']}", "", "", None
+         else:
+             return (
+                 f"Transcribed: {result['transcribed_text']}\nThinking: {result['thinking']}",
+                 result['response_text'],
+                 result['output_audio'],
+                 result['processing_times']
+             )
+     except Exception as e:
+         return f"An unexpected error occurred: {e}", "", "", None
+
+ # Default system prompt shown in the UI. The name system_prompt_0 came from the
+ # original notebook and was never defined in this file, so define it here (value
+ # borrowed from the demo above) to keep the script self-contained.
+ system_prompt_0 = "You are a helpful assistant. Please provide clear and concise responses."
+
+ # Define the Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# Asynchronous AI Conversation System")
+     gr.Markdown("Upload an audio file and provide a system prompt to get a response.")
+
+     with gr.Row():
+         audio_input = gr.Audio(label="Upload Audio File", type="filepath")
+         system_prompt_input = gr.Textbox(label="System Prompt", value=system_prompt_0)
+
+     process_button = gr.Button("Process Conversation")
+
+     with gr.Column():
+         status_output = gr.Textbox(label="Status/Transcription/Thinking", interactive=False)
+         response_text_output = gr.Textbox(label="AI Response Text", interactive=False)
+         response_audio_output = gr.Audio(label="AI Response Audio", interactive=False)
+         processing_times_output = gr.JSON(label="Processing Times")
+
+     # Link button click to the async function
+     process_button.click(
+         fn=process_audio_gradio,
+         inputs=[audio_input, system_prompt_input],
+         outputs=[status_output, response_text_output, response_audio_output, processing_times_output]
+     )
+
+ # Launch the Gradio interface.
+ # Gradio's launch() runs its own event loop and supports async click handlers, so
+ # the only requirement is that ai_conversation.initialize_models() has been awaited
+ # *before* the demo starts serving requests.
+
+ if __name__ == "__main__":
+     # The handler above expects a module-level ai_conversation; the original
+     # notebook created it in a separate cell, so create and initialize it here
+     # before launching.
+     ai_conversation = AsyncAIConversation()
+     asyncio.run(ai_conversation.initialize_models())
+     demo.launch(debug=False, share=True)