import asyncio
import logging
import time
from typing import Optional, Dict, Any

import torch
import librosa
import soundfile as sf
import gradio as gr
from transformers import (
    AutoProcessor, AutoModelForSpeechSeq2Seq,
    AutoModelForCausalLM, AutoTokenizer,
)
from kokoro import KPipeline

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class AsyncAIConversation:
    def __init__(self):
        self.stt_processor = None
        self.stt_model = None
        self.llm_tokenizer = None
        self.llm_model = None
        self.tts_synthesizer = None
        self.speaker_embedding = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")
    async def initialize_models(self):
        """Initialize all models asynchronously"""
        logger.info("Initializing models...")
        # Initialize STT model
        await self._init_stt_model()
        # Initialize LLM model
        await self._init_llm_model()
        # Initialize TTS model
        await self._init_tts_model()
        logger.info("All models initialized successfully!")
    async def _init_stt_model(self):
        """Initialize Speech-to-Text model"""
        logger.info("Loading STT model...")
        try:
            stt_model_id = "unsloth/whisper-small"
            # Alternative: unsloth/whisper-large-v3-turbo
            self.stt_processor = AutoProcessor.from_pretrained(stt_model_id)
            self.stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(stt_model_id)
            self.stt_model.to(self.device)
            logger.info("STT model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading STT model: {e}")
            raise
    async def _init_llm_model(self):
        """Initialize Large Language Model"""
        logger.info("Loading LLM model...")
        try:
            model_name = "unsloth/Qwen3-0.6B"
            # Alternative: unsloth/Qwen3-0.6B-unsloth-bnb-4bit
            self.llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype="auto",
                device_map="auto"
            )
            logger.info("LLM model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading LLM model: {e}")
            raise
    async def _init_tts_model(self):
        """Initialize Text-to-Speech model"""
        logger.info("Loading TTS model...")
        try:
            # Initialize Kokoro TTS pipeline ('a' = American English)
            self.tts_synthesizer = KPipeline(lang_code='a')
            logger.info("TTS model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading TTS model: {e}")
            raise
    async def speech_to_text(self, audio_file_path: str) -> str:
        """Convert speech to text asynchronously"""
        logger.info(f"Processing audio file: {audio_file_path}")
        try:
            # Load audio in a separate thread to avoid blocking the event loop
            def load_audio():
                return librosa.load(audio_file_path, sr=16000)

            loop = asyncio.get_event_loop()
            speech_array, sampling_rate = await loop.run_in_executor(None, load_audio)
            # Extract input features for the Whisper model
            input_features = self.stt_processor(
                speech_array,
                sampling_rate=sampling_rate,
                return_tensors="pt"
            ).input_features.to(self.device)
            # Generate predictions
            with torch.no_grad():
                predicted_ids = self.stt_model.generate(input_features)
            # Decode predictions
            transcription = self.stt_processor.batch_decode(predicted_ids, skip_special_tokens=True)
            result = transcription[0] if transcription else ""
            logger.info(f"STT result: {result}")
            return result
        except Exception as e:
            logger.error(f"Error in speech_to_text: {e}")
            return ""
    async def process_with_llm(self, text: str, system_prompt: Optional[str] = None) -> Dict[str, str]:
        """Process text with the LLM and return both thinking and response content"""
        logger.info(f"Processing text with LLM: {text[:50]}...")
        try:
            # Prepare chat messages
            messages = [
                {"role": "user", "content": text}
            ]
            if system_prompt:
                messages.insert(0, {"role": "system", "content": system_prompt})
            # Apply the chat template
            formatted_text = self.llm_tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False
            )
            # Tokenize
            model_inputs = self.llm_tokenizer([formatted_text], return_tensors="pt").to(self.llm_model.device)
            # Generate response
            with torch.no_grad():
                generated_ids = self.llm_model.generate(
                    **model_inputs,
                    max_new_tokens=512,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=self.llm_tokenizer.eos_token_id
                )
            # Keep only the newly generated tokens
            output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
            # Split off the thinking content, if any
            try:
                # 151668 is the id of the </think> token in the Qwen3 tokenizer
                index = len(output_ids) - output_ids[::-1].index(151668)
            except ValueError:
                index = 0
            thinking_content = self.llm_tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
            content = self.llm_tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
            result = {
                "thinking": thinking_content,
                "response": content
            }
            logger.info(f"LLM response generated: {content[:50]}...")
            return result
        except Exception as e:
            logger.error(f"Error in process_with_llm: {e}")
            return {"thinking": "", "response": "Sorry, I encountered an error processing your request."}
    async def text_to_speech(self, text: str, output_path: str = "response.wav") -> str:
        """Convert text to speech asynchronously"""
        logger.info(f"Converting text to speech: {text[:50]}...")
        try:
            # Generate speech in a separate thread to avoid blocking
            def generate_speech():
                # Generate audio using Kokoro TTS
                generator = self.tts_synthesizer(text, voice='af_heart')
                # Get the first generated audio chunk
                for i, (gs, ps, audio) in enumerate(generator):
                    if i == 0:  # Use the first chunk
                        return audio
                return None

            loop = asyncio.get_event_loop()
            audio_data = await loop.run_in_executor(None, generate_speech)
            if audio_data is None:
                raise ValueError("Failed to generate audio")
            # Save audio file with Kokoro's default sample rate (24000 Hz)
            sf.write(output_path, audio_data, samplerate=24000)
            logger.info(f"Audio saved to: {output_path}")
            return output_path
        except Exception as e:
            logger.error(f"Error in text_to_speech: {e}")
            return ""
    async def process_conversation(self, audio_file_path: str, system_prompt: Optional[str] = None) -> Dict[str, Any]:
        """Complete conversation pipeline: STT -> LLM -> TTS"""
        start_time = time.time()
        logger.info("Starting conversation processing...")
        try:
            # Step 1: Speech to Text
            stt_start = time.time()
            transcribed_text = await self.speech_to_text(audio_file_path)
            stt_time = time.time() - stt_start
            if not transcribed_text:
                return {"error": "Failed to transcribe audio"}
            # Step 2: Process with LLM
            llm_start = time.time()
            llm_result = await self.process_with_llm(transcribed_text, system_prompt)
            llm_time = time.time() - llm_start
            # Step 3: Text to Speech
            tts_start = time.time()
            audio_output_path = await self.text_to_speech(llm_result["response"])
            tts_time = time.time() - tts_start
            total_time = time.time() - start_time
            result = {
                "input_audio": audio_file_path,
                "transcribed_text": transcribed_text,
                "thinking": llm_result["thinking"],
                "response_text": llm_result["response"],
                "output_audio": audio_output_path,
                "processing_times": {
                    "stt": stt_time,
                    "llm": llm_time,
                    "tts": tts_time,
                    "total": total_time
                }
            }
            logger.info(f"Conversation processed successfully in {total_time:.2f} seconds")
            return result
        except Exception as e:
            logger.error(f"Error in process_conversation: {e}")
            return {"error": str(e)}
    async def batch_process(self, audio_files: list, system_prompt: Optional[str] = None) -> list:
        """Process multiple audio files concurrently"""
        logger.info(f"Processing {len(audio_files)} audio files...")
        # Create tasks for concurrent processing
        tasks = [
            self.process_conversation(audio_file, system_prompt)
            for audio_file in audio_files
        ]
        # Process all files concurrently
        results = await asyncio.gather(*tasks, return_exceptions=True)
        logger.info("Batch processing completed")
        return results
# Usage example and demo functions
# async def demo_conversation():
#     """Demonstration of the conversation system"""
#     # Initialize the conversation system
#     ai_conversation = AsyncAIConversation()
#     # Initialize all models
#     await ai_conversation.initialize_models()
#     # Example usage
#     audio_file = "/content/Recording 2.wav"  # Replace with your audio file path
#     system_prompt = "You are a helpful assistant. Please provide clear and concise responses."
#     # Process the conversation
#     result = await ai_conversation.process_conversation(audio_file, system_prompt)
#     if "error" in result:
#         print(f"Error: {result['error']}")
#     else:
#         print(f"Transcribed: {result['transcribed_text']}")
#         print(f"Thinking: {result['thinking']}")
#         print(f"Response: {result['response_text']}")
#         print(f"Audio saved to: {result['output_audio']}")
#         print(f"Processing times: {result['processing_times']}")

# async def demo_batch_processing():
#     """Demonstration of batch processing"""
#     ai_conversation = AsyncAIConversation()
#     await ai_conversation.initialize_models()
#     # Example batch processing
#     audio_files = [
#         "/content/Recording 1.wav",
#         "/content/Recording 2.wav",
#         "/content/Recording 3.wav"
#     ]
#     results = await ai_conversation.batch_process(audio_files)
#     for i, result in enumerate(results):
#         print(f"File {i+1}: {result}")

# Additional utility function for testing Kokoro TTS standalone
# async def test_kokoro_tts():
#     """Test Kokoro TTS functionality standalone"""
#     try:
#         tts_synthesizer = KPipeline(lang_code='a')
#         test_text = "Hello, this is a test of the Kokoro text-to-speech system."
#         # Generate audio
#         generator = tts_synthesizer(test_text, voice='af_heart')
#         for i, (gs, ps, audio) in enumerate(generator):
#             output_path = f"kokoro_test_{i}.wav"
#             sf.write(output_path, audio, 24000)
#             print(f"Test audio {i} saved to: {output_path}")
#             # Only process first chunk for testing
#             if i == 0:
#                 break
#     except Exception as e:
#         print(f"Error testing Kokoro TTS: {e}")
# Create the async function wrapper for Gradio
async def process_audio_gradio(audio_file, system_prompt_input):
    """Processes an audio file and system prompt for the Gradio interface."""
    if audio_file is None:
        return "Please upload an audio file.", "", None, None
    # Gradio provides the file path
    audio_path = audio_file
    # Process the conversation using the initialized ai_conversation instance
    try:
        result = await ai_conversation.process_conversation(
            audio_file_path=audio_path,
            system_prompt=system_prompt_input
        )
        if "error" in result:
            return f"Error: {result['error']}", "", None, None
        else:
            return (
                f"Transcribed: {result['transcribed_text']}\nThinking: {result['thinking']}",
                result['response_text'],
                result['output_audio'],
                result['processing_times']
            )
    except Exception as e:
        return f"An unexpected error occurred: {e}", "", None, None
# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Asynchronous AI Conversation System")
    gr.Markdown("Upload an audio file and provide a system prompt to get a response.")
    with gr.Row():
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
        system_prompt_input = gr.Textbox(label="System Prompt", value=system_prompt_0)
    process_button = gr.Button("Process Conversation")
    with gr.Column():
        status_output = gr.Textbox(label="Status/Transcription/Thinking", interactive=False)
        response_text_output = gr.Textbox(label="AI Response Text", interactive=False)
        response_audio_output = gr.Audio(label="AI Response Audio", interactive=False)
        processing_times_output = gr.JSON(label="Processing Times")
    # Link the button click to the async handler
    process_button.click(
        fn=process_audio_gradio,
        inputs=[audio_input, system_prompt_input],
        outputs=[status_output, response_text_output, response_audio_output, processing_times_output]
    )
# Launch the Gradio interface.
# Gradio's launch() runs its own event loop and awaits the async click handler,
# so no extra async wrapper is needed here. The key requirement is that
# ai_conversation.initialize_models() is awaited *before* the app starts serving,
# e.g. in a notebook cell executed earlier:
#   ai_conversation = AsyncAIConversation()
#   await ai_conversation.initialize_models()
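# Standalone-run sketch (assumption): the comments above describe the notebook flow.
# When this file is executed directly as a script, no prior cell has created the
# instance, so the same two steps are performed here before launching. The guard
# keeps this block inert if `ai_conversation` already exists (e.g. in a notebook),
# and asyncio.run is only safe here because no event loop is running yet.
if "ai_conversation" not in globals():
    ai_conversation = AsyncAIConversation()
    asyncio.run(ai_conversation.initialize_models())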
if __name__ == "__main__":
    # Gradio launch itself runs an event loop.
    # ai_conversation must already be initialized (see the block above) before this runs.
    demo.launch(debug=False, share=True)