jsbeaudry committed on
Commit 28b8f48 · verified · 1 Parent(s): 2201e62

Update app.py

Files changed (1)
  1. app.py +144 -17
app.py CHANGED
@@ -5,21 +5,30 @@ import numpy as np
 import soundfile as sf
 from transformers import (
     AutoProcessor, AutoModelForSpeechSeq2Seq,
-    AutoModelForCausalLM, AutoTokenizer
+    AutoModelForCausalLM, AutoTokenizer,
+    pipeline, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
 )
+from datasets import load_dataset
 import logging
 from typing import Optional, Dict, Any
 import time
 from pathlib import Path
 
 from kokoro import KPipeline
+from IPython.display import display, Audio
+
+
 import gradio as gr
+import asyncio
+import os
 
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+
+
 system_prompt_0 = """You are a highly trained U.S. Tax Assistant AI, designed to help individuals and small businesses understand, plan, and file their taxes according to federal and state tax laws. You explain complex tax concepts in simple, accurate, and actionable terms, using IRS guidelines, up-to-date tax code knowledge, and best practices for compliance and savings. You act as an explainer, educator, and assistant—not a certified tax preparer or legal advisor."""
 
 
@@ -30,19 +39,31 @@ class AsyncAIConversation:
         self.llm_tokenizer = None
         self.llm_model = None
         self.tts_synthesizer = None
+        self.speaker_embedding = None
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         logger.info(f"Using device: {self.device}")
 
     async def initialize_models(self):
+        """Initialize all models asynchronously"""
        logger.info("Initializing models...")
+
+        # Initialize STT model
         await self._init_stt_model()
+
+        # Initialize LLM model
         await self._init_llm_model()
+
+        # Initialize TTS model
         await self._init_tts_model()
+
         logger.info("All models initialized successfully!")
 
     async def _init_stt_model(self):
+        """Initialize Speech-to-Text model"""
+        logger.info("Loading STT model...")
         try:
             stt_model_id = "unsloth/whisper-small"
+            #unsloth/whisper-large-v3-turbo
             self.stt_processor = AutoProcessor.from_pretrained(stt_model_id)
             self.stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(stt_model_id)
             self.stt_model.to(self.device)
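
The _init_stt_model hunk above loads unsloth/whisper-small through the standard transformers processor/model pair. For reference, a minimal standalone transcription sketch with that model; the audio path is a placeholder, and the same steps reappear later in speech_to_text:

# Minimal sketch of the Whisper flow this commit wires up
# ("sample.wav" is a placeholder path, not a file from this repo).
import librosa
import torch
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

processor = AutoProcessor.from_pretrained("unsloth/whisper-small")
model = AutoModelForSpeechSeq2Seq.from_pretrained("unsloth/whisper-small")

speech, sr = librosa.load("sample.wav", sr=16000)  # Whisper expects 16 kHz audio
inputs = processor(speech, sampling_rate=sr, return_tensors="pt").input_features

with torch.no_grad():
    ids = model.generate(inputs)

print(processor.batch_decode(ids, skip_special_tokens=True)[0])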
@@ -52,8 +73,11 @@ class AsyncAIConversation:
             raise
 
     async def _init_llm_model(self):
+        """Initialize Large Language Model"""
+        logger.info("Loading LLM model...")
         try:
             model_name = "unsloth/Qwen3-0.6B"
+            #unsloth/Qwen3-0.6B-unsloth-bnb-4bit
             self.llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
             self.llm_model = AutoModelForCausalLM.from_pretrained(
                 model_name,
@@ -66,7 +90,10 @@ class AsyncAIConversation:
             raise
 
     async def _init_tts_model(self):
+        """Initialize Text-to-Speech model"""
+        logger.info("Loading TTS model...")
         try:
+            # Initialize Kokoro TTS pipeline
             self.tts_synthesizer = KPipeline(lang_code='a')
             logger.info("TTS model loaded successfully")
         except Exception as e:
@@ -74,34 +101,56 @@ class AsyncAIConversation:
             raise
 
     async def speech_to_text(self, audio_file_path: str) -> str:
+        """Convert speech to text asynchronously"""
+        logger.info(f"Processing audio file: {audio_file_path}")
+
         try:
+            # Load audio in a separate thread to avoid blocking
             def load_audio():
                 return librosa.load(audio_file_path, sr=16000)
 
             loop = asyncio.get_event_loop()
             speech_array, sampling_rate = await loop.run_in_executor(None, load_audio)
 
+            # Convert to tensor
+            speech_array_pt = torch.from_numpy(speech_array).unsqueeze(0).to(self.device)
+
+            # Process input features
             input_features = self.stt_processor(
                 speech_array,
                 sampling_rate=sampling_rate,
                 return_tensors="pt"
             ).input_features.to(self.device)
 
+            # Generate predictions
             with torch.no_grad():
                 predicted_ids = self.stt_model.generate(input_features)
 
+            # Decode predictions
             transcription = self.stt_processor.batch_decode(predicted_ids, skip_special_tokens=True)
-            return transcription[0] if transcription else ""
+
+            result = transcription[0] if transcription else ""
+            logger.info(f"STT result: {result}")
+            return result
+
         except Exception as e:
             logger.error(f"Error in speech_to_text: {e}")
             return ""
 
     async def process_with_llm(self, text: str, system_prompt: Optional[str] = None) -> Dict[str, str]:
+        """Process text with LLM and return both thinking and content"""
+        logger.info(f"Processing text with LLM: {text[:50]}...")
+
         try:
-            messages = [{"role": "user", "content": text}]
+            # Prepare messages
+            messages = [
+                {"role": "user", "content": text}
+            ]
+
             if system_prompt:
                 messages.insert(0, {"role": "system", "content": system_prompt})
 
+            # Apply chat template
             formatted_text = self.llm_tokenizer.apply_chat_template(
                 messages,
                 tokenize=False,
@@ -109,8 +158,10 @@ class AsyncAIConversation:
                 enable_thinking=False
             )
 
+            # Tokenize
             model_inputs = self.llm_tokenizer([formatted_text], return_tensors="pt").to(self.llm_model.device)
 
+            # Generate response
             with torch.no_grad():
                 generated_ids = self.llm_model.generate(
                     **model_inputs,
@@ -120,9 +171,12 @@ class AsyncAIConversation:
                     pad_token_id=self.llm_tokenizer.eos_token_id
                 )
 
+            # Extract new tokens
             output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
 
+            # Parse thinking content
             try:
+                # Find the end of thinking token (</think>)
                 index = len(output_ids) - output_ids[::-1].index(151668)
             except ValueError:
                 index = 0
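
The bare constant 151668 above is the token id the code treats as Qwen3's </think> marker. A small sketch of the same split on a plain list (the other ids are made up):

# Sketch of the "</think>" split used above; 151668 is the id the code
# searches for, the remaining ids here are placeholders.
output_ids = [101, 202, 151668, 303, 404]

try:
    # Index just past the LAST occurrence of the </think> id
    index = len(output_ids) - output_ids[::-1].index(151668)
except ValueError:
    index = 0  # no </think> found; everything counts as the response

thinking_ids, response_ids = output_ids[:index], output_ids[index:]
print(thinking_ids, response_ids)  # [101, 202, 151668] [303, 404]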
@@ -130,21 +184,31 @@ class AsyncAIConversation:
             thinking_content = self.llm_tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
             content = self.llm_tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
 
-            return {
+            result = {
                 "thinking": thinking_content,
                 "response": content
             }
 
+            logger.info(f"LLM response generated: {content[:50]}...")
+            return result
+
         except Exception as e:
             logger.error(f"Error in process_with_llm: {e}")
             return {"thinking": "", "response": "Sorry, I encountered an error processing your request."}
 
     async def text_to_speech(self, text: str, output_path: str = "response.wav") -> str:
+        """Convert text to speech asynchronously"""
+        logger.info(f"Converting text to speech: {text[:50]}...")
+
         try:
+            # Generate speech in a separate thread to avoid blocking
            def generate_speech():
+                # Generate audio using Kokoro TTS
                 generator = self.tts_synthesizer(text, voice='af_heart')
+
+                # Get the first generated audio chunk
                 for i, (gs, ps, audio) in enumerate(generator):
-                    if i == 0:
+                    if i == 0:  # Use the first chunk
                         return audio
                 return None
 
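The text_to_speech changes above keep only the first chunk that Kokoro yields. A rough standalone sketch of that usage, assuming the same voice and 24 kHz output:

# Standalone sketch of the Kokoro usage above: take the first audio chunk
# and write it out at Kokoro's 24 kHz sample rate.
import soundfile as sf
from kokoro import KPipeline

tts = KPipeline(lang_code='a')

for i, (gs, ps, audio) in enumerate(tts("Hello from the tax assistant.", voice='af_heart')):
    if i == 0:
        sf.write("response.wav", audio, samplerate=24000)
        break

Since only chunk 0 is written, replies longer than one Kokoro segment get truncated; concatenating all yielded chunks would preserve the full response.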
 
@@ -154,47 +218,105 @@ class AsyncAIConversation:
             if audio_data is None:
                 raise ValueError("Failed to generate audio")
 
+            # Save audio file with Kokoro's default sample rate (24000 Hz)
             sf.write(output_path, audio_data, samplerate=24000)
+
+            logger.info(f"Audio saved to: {output_path}")
             return output_path
+
         except Exception as e:
             logger.error(f"Error in text_to_speech: {e}")
             return ""
 
     async def process_conversation(self, audio_file_path: str, system_prompt: Optional[str] = None) -> Dict[str, Any]:
+        """Complete conversation pipeline: STT -> LLM -> TTS"""
+        start_time = time.time()
+        logger.info("Starting conversation processing...")
+
         try:
+            # Step 1: Speech to Text
+            stt_start = time.time()
             transcribed_text = await self.speech_to_text(audio_file_path)
+            stt_time = time.time() - stt_start
+
             if not transcribed_text:
                 return {"error": "Failed to transcribe audio"}
 
+            # Step 2: Process with LLM
+            llm_start = time.time()
             llm_result = await self.process_with_llm(transcribed_text, system_prompt)
+            llm_time = time.time() - llm_start
+
+            # Step 3: Text to Speech
+            tts_start = time.time()
             audio_output_path = await self.text_to_speech(llm_result["response"])
+            tts_time = time.time() - tts_start
 
-            return {
+            total_time = time.time() - start_time
+
+            result = {
                 "input_audio": audio_file_path,
                 "transcribed_text": transcribed_text,
                 "thinking": llm_result["thinking"],
                 "response_text": llm_result["response"],
                 "output_audio": audio_output_path,
+                "processing_times": {
+                    "stt": stt_time,
+                    "llm": llm_time,
+                    "tts": tts_time,
+                    "total": total_time
+                }
             }
+
+            logger.info(f"Conversation processed successfully in {total_time:.2f} seconds")
+            return result
+
         except Exception as e:
             logger.error(f"Error in process_conversation: {e}")
             return {"error": str(e)}
 
-# ---------------------------- GLOBAL CONVERSATION OBJECT ----------------------------
-ai_conversation = AsyncAIConversation()
+    async def batch_process(self, audio_files: list, system_prompt: Optional[str] = None) -> list:
+        """Process multiple audio files concurrently"""
+        logger.info(f"Processing {len(audio_files)} audio files...")
+
+        # Create tasks for concurrent processing
+        tasks = [
+            self.process_conversation(audio_file, system_prompt)
+            for audio_file in audio_files
+        ]
+
+        # Process all files concurrently
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+
+        logger.info("Batch processing completed")
+        return results
 
-# ---------------------------- DEMO INITIALIZATION ----------------------------
+# Initialize the conversation system
+ai_conversation = AsyncAIConversation()
+
+# Usage example and demo functions
 async def demo_conversation():
+    """Demonstration of the conversation system"""
+
+    # Initialize all models
     await ai_conversation.initialize_models()
 
-# ---------------------------- GRADIO WRAPPER ----------------------------
+
+
+# Create the async function wrapper for Gradio
 async def process_audio_gradio(audio_file, system_prompt_input):
+
+    """Processes audio file and system prompt for Gradio interface."""
     if audio_file is None:
         return "Please upload an audio file.", "", "", None
 
+    # Gradio provides the file path
+    audio_path = audio_file
+
+    # Process the conversation using the initialized ai_conversation instance
     try:
         result = await ai_conversation.process_conversation(
-            audio_file_path=audio_file,
+            audio_file_path=audio_path,
            system_prompt=system_prompt_input
         )
 
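
The new batch_process method fans out process_conversation calls with asyncio.gather. A hypothetical driver for it; the file names are placeholders and the models must be initialized first:

# Hypothetical driver for batch_process; "q1.wav" etc. are placeholders.
import asyncio

async def run_batch():
    await ai_conversation.initialize_models()
    results = await ai_conversation.batch_process(
        ["q1.wav", "q2.wav", "q3.wav"],
        system_prompt=system_prompt_0,
    )
    for r in results:
        # return_exceptions=True means failures come back as exception objects
        print(r if isinstance(r, Exception) else r["response_text"])

asyncio.run(run_batch())

Because every task shares the same model instances and device, gather mainly overlaps audio I/O and executor work rather than running inference truly in parallel.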
 
@@ -205,12 +327,12 @@ async def process_audio_gradio(audio_file, system_prompt_input):
             f"Transcribed: {result['transcribed_text']}\nThinking: {result['thinking']}",
             result['response_text'],
             result['output_audio'],
-            None
+            result['processing_times']
         )
     except Exception as e:
-        return f"Unexpected error: {e}", "", "", None
+        return f"An unexpected error occurred: {e}", "", "", None
 
-# ---------------------------- GRADIO INTERFACE ----------------------------
+# Define the Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("# Asynchronous AI Conversation System")
     gr.Markdown("Upload an audio file and provide a system prompt to get a response.")
@@ -227,16 +349,21 @@ with gr.Blocks() as demo:
     response_audio_output = gr.Audio(label="AI Response Audio", interactive=False)
     processing_times_output = gr.JSON(label="Processing Times")
 
+    # Link button click to the async function
     process_button.click(
         fn=process_audio_gradio,
         inputs=[audio_input, system_prompt_input],
         outputs=[status_output, response_text_output, response_audio_output, processing_times_output]
     )
 
-# ---------------------------- MAIN LAUNCH ----------------------------
+
 if __name__ == "__main__":
-    def initiate():
+
+    def initiate():
         asyncio.run(demo_conversation())
 
     initiate()
-    demo.launch(debug=False, share=True)
+
+    # Gradio launch itself runs an event loop.
+    # Ensure ai_conversation is initialized in the notebook before this cell is run.
+    demo.launch(debug=False, share=True)
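
As a usage sketch for the updated file, one turn of the pipeline can also be driven without the Gradio UI; the wav path below is a placeholder:

# Hedged single-turn driver for the updated app.py ("question.wav" is a placeholder).
import asyncio

async def main():
    await ai_conversation.initialize_models()
    result = await ai_conversation.process_conversation(
        "question.wav",
        system_prompt=system_prompt_0,
    )
    print(result.get("transcribed_text"))
    print(result.get("response_text"))
    print(result.get("processing_times"))  # per-stage timings added in this commit

asyncio.run(main())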