import os
import tempfile
import uuid
from typing import Tuple

import gradio as gr
import nemo.collections.asr as nemo_asr
import soundfile as sf
import torch
from fastapi import FastAPI
from speechbrain.pretrained import EncoderClassifier
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize FastAPI app
app = FastAPI()

# Global variables for models
asr_model = None
emotion_model = None
llm_model = None
llm_tokenizer = None
conversation_history = {}


def load_models():
    """Load the ASR, emotion recognition, and conversational LLM models."""
    global asr_model, emotion_model, llm_model, llm_tokenizer

    try:
        # Load ASR model
        print("Loading ASR model...")
        asr_model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
        print("ASR model loaded successfully")

        # Load emotion recognition model
        print("Loading emotion model...")
        emotion_model = EncoderClassifier.from_hparams(
            source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
            savedir="./emotion_model_cache",
        )
        print("Emotion model loaded successfully")

        # Load LLM for conversation
        print("Loading LLM...")
        model_name = "microsoft/DialoGPT-medium"  # Lighter alternative to Vicuna
        llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
        llm_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
        )

        # Add padding token if not present
        if llm_tokenizer.pad_token is None:
            llm_tokenizer.pad_token = llm_tokenizer.eos_token

        print("All models loaded successfully")
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        raise


def transcribe_audio(audio_path: str) -> Tuple[str, str]:
    """Transcribe audio and detect the speaker's emotion."""
    try:
        # ASR transcription (newer NeMo versions return Hypothesis objects with a .text field)
        transcription = asr_model.transcribe([audio_path])
        first = transcription[0]
        text = first.text if hasattr(first, "text") else str(first)

        # Emotion detection: classify_file returns (out_prob, score, index, text_lab)
        out_prob, score, index, text_lab = emotion_model.classify_file(audio_path)
        emotion = text_lab[0] if isinstance(text_lab, (list, tuple)) else str(text_lab)

        return text, emotion
    except Exception as e:
        print(f"Error in transcription: {str(e)}")
        return f"Error: {str(e)}", "unknown"


def generate_response(user_text: str, emotion: str, user_id: str) -> str:
    """Generate a contextual response based on the user's words and detected emotion."""
    try:
        # Get conversation history
        if user_id not in conversation_history:
            conversation_history[user_id] = []

        # Add emotion context to the input and record it in the history
        emotional_context = f"[User is feeling {emotion}] {user_text}"
        conversation_history[user_id].append(emotional_context)

        # Keep only the last 10 turns (5 exchanges) to manage memory
        if len(conversation_history[user_id]) > 10:
            conversation_history[user_id] = conversation_history[user_id][-10:]

        # Build the model input from the last 3 turns
        input_text = " ".join(conversation_history[user_id][-3:])

        # Tokenize and generate
        inputs = llm_tokenizer.encode(input_text, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = inputs.cuda()

        with torch.no_grad():
            outputs = llm_model.generate(
                inputs,
                max_new_tokens=100,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                pad_token_id=llm_tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens (generate returns prompt + continuation)
        response = llm_tokenizer.decode(
            outputs[0][inputs.shape[-1]:], skip_special_tokens=True
        ).strip()

        # Add to conversation history
        conversation_history[user_id].append(response)

        return response if response else "I understand your feelings. How can I help you today?"
    except Exception as e:
        print(f"Error generating response: {str(e)}")
        return "I'm having trouble processing that right now. Could you try again?"


def process_audio_input(audio_file, user_id: str = None) -> Tuple[str, str, str, str]:
    """Main processing function for audio input."""
    if user_id is None:
        user_id = str(uuid.uuid4())

    if audio_file is None:
        return "No audio file provided", "", "", user_id

    temp_path = None
    try:
        # Resolve the audio to a path on disk, handling the formats Gradio can send
        if isinstance(audio_file, str):
            # gr.Audio(type="filepath") passes the recording/upload as a path string
            audio_path = audio_file
        elif isinstance(audio_file, tuple):
            # Raw (sample_rate, numpy_data) tuple: write it to a temporary WAV file
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                sf.write(tmp_file.name, audio_file[1], audio_file[0])
                temp_path = tmp_file.name
            audio_path = temp_path
        else:
            # File-like upload object
            audio_path = audio_file.name

        # Process audio
        transcription, emotion = transcribe_audio(audio_path)

        # Generate response
        response = generate_response(transcription, emotion, user_id)

        return transcription, emotion, response, user_id
    except Exception as e:
        error_msg = f"Processing error: {str(e)}"
        print(error_msg)
        return error_msg, "error", "I'm sorry, I couldn't process your audio.", user_id
    finally:
        # Clean up any temporary file we created
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)


def get_conversation_history(user_id: str) -> str:
    """Get formatted conversation history for a user."""
    if user_id not in conversation_history or not conversation_history[user_id]:
        return "No conversation history yet."

    history = conversation_history[user_id]
    formatted_history = []

    # History alternates user turn, AI turn
    for i in range(0, len(history), 2):
        if i + 1 < len(history):
            # Strip the "[User is feeling <emotion>] " prefix from the stored user turn
            user_msg = history[i].split("] ", 1)[-1]
            bot_msg = history[i + 1]
            formatted_history.append(f"**You:** {user_msg}")
            formatted_history.append(f"**AI:** {bot_msg}")

    return "\n\n".join(formatted_history) if formatted_history else "No conversation history yet."


def clear_conversation(user_id: str) -> str:
    """Clear conversation history for a user."""
    if user_id in conversation_history:
        conversation_history[user_id] = []
    return "Conversation history cleared."
# Load models on startup
print("Initializing models...")
load_models()
print("Models initialized successfully")

# Create Gradio interface
with gr.Blocks(title="Emotional Conversational AI", theme=gr.themes.Soft()) as iface:
    gr.Markdown("# 🎤 Emotional Conversational AI")
    gr.Markdown("Upload audio or use your microphone to have an emotional conversation with AI")

    # User ID state
    user_id_state = gr.State(value=str(uuid.uuid4()))

    with gr.Row():
        with gr.Column(scale=2):
            # Audio input
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="🎙️ Record or Upload Audio",
            )

            # Process button
            process_btn = gr.Button("🚀 Process Audio", variant="primary", size="lg")

        with gr.Column(scale=3):
            # Output displays
            transcription_output = gr.Textbox(
                label="📝 Transcription",
                placeholder="Your speech will appear here...",
                max_lines=3,
            )
            emotion_output = gr.Textbox(
                label="😊 Detected Emotion",
                placeholder="Detected emotion will appear here...",
                max_lines=1,
            )
            response_output = gr.Textbox(
                label="🤖 AI Response",
                placeholder="AI response will appear here...",
                max_lines=5,
            )

    with gr.Row():
        with gr.Column():
            # Conversation history
            history_output = gr.Textbox(
                label="💬 Conversation History",
                placeholder="Your conversation history will appear here...",
                max_lines=10,
                interactive=False,
            )

        with gr.Column():
            # Control buttons
            show_history_btn = gr.Button("📖 Show History", variant="secondary")
            clear_history_btn = gr.Button("🗑️ Clear History", variant="stop")
            new_session_btn = gr.Button("🆕 New Session", variant="secondary")

    # Event handlers
    process_btn.click(
        fn=process_audio_input,
        inputs=[audio_input, user_id_state],
        outputs=[transcription_output, emotion_output, response_output, user_id_state],
    )

    show_history_btn.click(
        fn=get_conversation_history,
        inputs=[user_id_state],
        outputs=[history_output],
    )

    clear_history_btn.click(
        fn=clear_conversation,
        inputs=[user_id_state],
        outputs=[history_output],
    )

    new_session_btn.click(
        fn=lambda: (str(uuid.uuid4()), "New session started!"),
        inputs=None,
        outputs=[user_id_state, history_output],
    )

# Mount Gradio app onto FastAPI
app = gr.mount_gradio_app(app, iface, path="/")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
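
# A minimal sketch of how the mounted app could be exercised from another process,
# assuming the server above is running on localhost:7860 and the `gradio_client`
# package is installed. The file name and api_name below are assumptions (Gradio
# derives api_name from the bound function name); check the app's "Use via API"
# page for the actual endpoint name before relying on it.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://localhost:7860/")
#   transcription, emotion, reply, session_id = client.predict(
#       handle_file("sample.wav"),   # hypothetical local WAV file
#       None,                        # let the server assign a user_id
#       api_name="/process_audio_input",
#   )
#   print(transcription, emotion, reply, sep="\n")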