# app.py: Hugging Face Space by Devakumar868.
# Last commit: "Update app.py" (322ba51, verified); file size 10.8 kB.
import os, torch, numpy as np, soundfile as sf, gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
import nemo.collections.asr as nemo_asr
from TTS.api import TTS
from sklearn.linear_model import LogisticRegression
from datasets import load_dataset
import tempfile
import gc
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Prefer the GPU whenever PyTorch can see one; otherwise run everything on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42            # shared RNG seed for PyTorch and NumPy
SAMPLE_RATE = 22050  # sample rate (Hz) used for returned audio
TEMPERATURE = 0.7    # sampling temperature for LLM generation

# Seed PyTorch first, then NumPy's global RNG.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Startup banner: one line per fact, plus the CUDA version when a GPU is present.
_system_info = [
    "πŸš€ System Info:",
    f"Device: {DEVICE}",
    f"NumPy: {np.__version__}",
    f"PyTorch: {torch.__version__}",
]
if torch.cuda.is_available():
    _system_info.append(f"CUDA: {torch.version.cuda}")
print("\n".join(_system_info))
class ConversationalAI:
    """Speech-in / speech-out assistant: ASR -> (demo) emotion -> LLM -> TTS.

    All models are loaded eagerly by ``__init__`` via :meth:`setup_models`.
    """

    # Returned by ``transcribe`` when recognition fails.
    _ASR_ERROR_TEXT = "Sorry, I couldn't understand the audio."

    def __init__(self):
        print("πŸ”„ Initializing Conversational AI...")
        self.setup_models()
        print("βœ… All models loaded successfully!")

    def setup_models(self):
        """Load the ASR, SER, LLM and TTS components, then free cached memory."""
        self._load_asr()
        self._load_ser()
        self._load_llm()
        self._load_tts()
        # Memory cleanup
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

    def _load_asr(self):
        """Load Parakeet RNNT into ``self.asr_model``, else Whisper into ``self.asr_pipeline``."""
        print("πŸ“’ Loading ASR model...")
        try:
            self.asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
                "nvidia/parakeet-rnnt-1.1b"
            ).to(DEVICE).eval()
            print("βœ… Parakeet ASR loaded")
        except Exception as e:
            # ``asr_model`` stays unset here; ``hasattr`` selects the backend later.
            print(f"⚠️ Parakeet failed: {e}")
            print("πŸ”„ Loading Whisper fallback...")
            self.asr_pipeline = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-base.en",
                device=0 if DEVICE == "cuda" else -1,
            )
            print("βœ… Whisper ASR loaded")

    def _load_ser(self):
        """Fit a placeholder emotion classifier (demo only)."""
        print("🎭 Setting up emotion recognition...")
        # NOTE: trained on random features, so predictions say nothing about the audio.
        X_demo = np.random.rand(100, 128)
        y_demo = np.random.randint(0, 5, 100)  # 5 emotions: neutral, happy, sad, angry, surprised
        self.ser_clf = LogisticRegression().fit(X_demo, y_demo)
        self.emotion_labels = ["neutral", "happy", "sad", "angry", "surprised"]
        print("βœ… SER model ready")

    def _load_llm(self):
        """Load DialoGPT-medium and its tokenizer (4-bit quantized when on CUDA)."""
        print("🧠 Loading LLM...")
        model_name = "microsoft/DialoGPT-medium"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # Truncate over-long prompts from the left so the trailing
        # "Assistant (feeling ...):" cue is never cut off.
        self.tokenizer.truncation_side = "left"
        load_kwargs = {"low_cpu_mem_usage": True}
        if DEVICE == "cuda":
            # bitsandbytes 4-bit quantization with fp16 compute targets CUDA;
            # on CPU the model is loaded in default precision instead.
            load_kwargs.update(
                quantization_config=BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                ),
                device_map="auto",
                torch_dtype=torch.float16,
            )
        self.llm_model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
        print("βœ… LLM loaded")

    def _load_tts(self):
        """Load the Coqui Tacotron2 voice into ``self.tts``, or set it to ``None`` on failure."""
        print("πŸ—£οΈ Loading TTS...")
        try:
            self.tts = TTS("tts_models/en/ljspeech/tacotron2-DDC").to(DEVICE)
            print("βœ… TTS loaded")
        except Exception as e:
            print(f"⚠️ TTS error: {e}")
            self.tts = None

    @staticmethod
    def _to_mono_float32(samples):
        """Return *samples* as 1-D float32, scaling integer PCM to about [-1, 1] and averaging channels."""
        data = np.asarray(samples)
        if np.issubdtype(data.dtype, np.integer):
            # e.g. Gradio's int16 microphone output.
            data = data.astype(np.float32) / np.float32(np.iinfo(data.dtype).max)
        else:
            data = data.astype(np.float32)
        if data.ndim > 1:
            # Assumes (samples, channels) layout, as gr.Audio(type="numpy") produces.
            data = data.mean(axis=1)
        return data.astype(np.float32, copy=False)

    def _transcribe_or_none(self, audio):
        """Transcribe ``(sample_rate, samples)`` audio.

        Returns:
            The transcript, ``""`` for empty audio, or ``None`` if recognition failed.
        """
        try:
            sample_rate, raw_samples = audio
            samples = self._to_mono_float32(raw_samples)
            if samples.size == 0:
                return ""
            if hasattr(self, "asr_model"):
                # Parakeet transcribes from file paths. The temporary directory is
                # removed even if writing or transcription raises.
                with tempfile.TemporaryDirectory() as tmp_dir:
                    wav_path = os.path.join(tmp_dir, "input.wav")
                    sf.write(wav_path, samples, sample_rate)
                    result = self.asr_model.transcribe([wav_path])
                # Some NeMo versions return (best_hypotheses, all_hypotheses) for RNNT models.
                if isinstance(result, tuple):
                    result = result[0]
                hypothesis = result[0]
                return hypothesis.text if hasattr(hypothesis, "text") else str(hypothesis)
            # Whisper fallback
            return self.asr_pipeline({"sampling_rate": sample_rate, "raw": samples})["text"]
        except Exception as e:
            print(f"ASR Error: {e}")
            return None

    def transcribe(self, audio):
        """Convert speech to text.

        Args:
            audio: ``(sample_rate, samples)`` tuple as produced by ``gr.Audio(type="numpy")``.

        Returns:
            The transcript (``""`` for empty audio), or a fixed apology string on failure.
        """
        text = self._transcribe_or_none(audio)
        return self._ASR_ERROR_TEXT if text is None else text

    def predict_emotion(self):
        """Predict emotion from audio (simplified demo).

        NOTE: the classifier is fed random features, not the audio, so the label is arbitrary.
        """
        emotion_idx = self.ser_clf.predict(np.random.rand(1, 128))[0]
        return self.emotion_labels[emotion_idx]

    def generate_response(self, text, emotion):
        """Generate a conversational reply to *text*, tagged with *emotion* in the prompt.

        Returns a fallback sentence if generation fails or produces nothing.
        """
        try:
            # Create emotion-aware prompt
            prompt = f"Human: {text}\nAssistant (feeling {emotion}):"
            # Calling the tokenizer (rather than ``encode``) also yields the attention mask.
            encoded = self.tokenizer(
                prompt, return_tensors="pt", max_length=512, truncation=True
            ).to(self.llm_model.device)
            prompt_len = encoded["input_ids"].shape[1]
            with torch.no_grad():
                outputs = self.llm_model.generate(
                    **encoded,
                    max_new_tokens=100,
                    temperature=TEMPERATURE,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    no_repeat_ngram_size=2,
                    top_p=0.9,
                )
            # Decode only the newly generated tokens, then cut at any hallucinated next turn.
            response = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
            response = response.split("Human:")[0].strip()
            return response if response else "I understand. Please tell me more."
        except Exception as e:
            print(f"LLM Error: {e}")
            return "I'm having trouble processing that. Could you please rephrase?"

    @staticmethod
    def _silence():
        """One second of int16 silence at ``SAMPLE_RATE``."""
        return (SAMPLE_RATE, np.zeros(SAMPLE_RATE, dtype=np.int16))

    def synthesize(self, text):
        """Convert text to ``(sample_rate, int16 samples)``, peak-normalised.

        Returns one second of silence if TTS is unavailable, fails, or yields no samples.
        """
        try:
            if self.tts is None:
                return self._silence()
            wav = np.asarray(self.tts.tts(text=text), dtype=np.float32).reshape(-1)
            if wav.size == 0:
                return self._silence()
            # Normalize audio
            peak = float(np.max(np.abs(wav)))
            if peak > 0:
                wav = wav / peak
            # Prefer the synthesizer's own output rate; fall back to the configured one.
            synthesizer = getattr(self.tts, "synthesizer", None)
            out_rate = getattr(synthesizer, "output_sample_rate", None) or SAMPLE_RATE
            return (out_rate, (wav * 32767).astype(np.int16))
        except Exception as e:
            print(f"TTS Error: {e}")
            return self._silence()

    def process_conversation(self, audio_input, chat_history):
        """Main pipeline: Speech -> Emotion -> LLM -> Speech.

        Returns:
            ``(chat_history, audio_response_or_None, status_text)``. On success a
            ``[user_text, ai_response]`` pair is appended to ``chat_history``.
        """
        if audio_input is None:
            return chat_history, None, ""
        if chat_history is None:
            chat_history = []
        try:
            # Step 1: Speech to Text
            user_text = self._transcribe_or_none(audio_input)
            if user_text is None:
                # Report the failure instead of sending the apology text to the LLM as user speech.
                return chat_history, None, self._ASR_ERROR_TEXT
            if not user_text.strip():
                return chat_history, None, "No speech detected."
            # Step 2: Emotion Recognition
            emotion = self.predict_emotion()
            # Step 3: Generate Response
            ai_response = self.generate_response(user_text, emotion)
            # Step 4: Text to Speech
            audio_response = self.synthesize(ai_response)
            # Update chat history
            chat_history.append([user_text, ai_response])
            # Memory cleanup
            if DEVICE == "cuda":
                torch.cuda.empty_cache()
            gc.collect()
            return chat_history, audio_response, f"You said: {user_text} (detected emotion: {emotion})"
        except Exception as e:
            error_msg = f"Error processing conversation: {e}"
            print(error_msg)
            return chat_history, None, error_msg
# Build the shared model stack once at import time; every Gradio request reuses it.
_startup_banner = "πŸš€ Starting Conversational AI..."
print(_startup_banner)
ai_system = ConversationalAI()
# Gradio Interface
def create_interface():
    """Build the Gradio Blocks UI wired to the module-level ``ai_system``.

    Returns:
        The unlaunched ``gr.Blocks`` app.
    """
    # Report the ASR backend that actually loaded: Parakeet, or the Whisper
    # fallback when Parakeet failed and ``asr_model`` was never set.
    asr_name = "Parakeet" if hasattr(ai_system, "asr_model") else "Whisper"

    with gr.Blocks(
        title="Emotion-Aware Conversational AI",
        theme=gr.themes.Soft()
    ) as demo:
        gr.HTML("""
        <div style="text-align: center; margin-bottom: 2rem;">
        <h1>πŸ€– Emotion-Aware Conversational AI</h1>
        <p>Speak naturally and get intelligent responses with emotion recognition</p>
        </div>
        """)
        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Conversation History",
                    height=400,
                    show_copy_button=True
                )
                audio_input = gr.Audio(
                    label="🎀 Speak to AI",
                    sources=["microphone"],
                    type="numpy",
                    format="wav"
                )
                with gr.Row():
                    submit_btn = gr.Button("πŸ’¬ Process Speech", variant="primary", scale=2)
                    clear_btn = gr.Button("πŸ—‘οΈ Clear Chat", variant="secondary", scale=1)
            with gr.Column(scale=1):
                audio_output = gr.Audio(
                    label="πŸ”Š AI Response",
                    type="numpy",
                    autoplay=True
                )
                status_display = gr.Textbox(
                    label="πŸ“Š Status",
                    lines=3,
                    interactive=False
                )
                gr.HTML(f"""
                <div style="padding: 1rem; background: #f0f9ff; border-radius: 0.5rem;">
                <h3>πŸ”§ System Info</h3>
                <p><strong>Device:</strong> {DEVICE.upper()}</p>
                <p><strong>PyTorch:</strong> {torch.__version__}</p>
                <p><strong>Models:</strong> {asr_name} + DialoGPT + TTS</p>
                <p><strong>Features:</strong> Emotion Recognition</p>
                </div>
                """)

        def process_audio(audio, history):
            """Run one speech turn through the shared pipeline."""
            return ai_system.process_conversation(audio, history)

        def clear_conversation():
            """Reset chat, response audio and status, then free cached memory."""
            if DEVICE == "cuda":
                torch.cuda.empty_cache()
            gc.collect()
            return [], None, "Conversation cleared."

        # Event handlers
        # NOTE(review): both the button and ``audio_input.change`` trigger processing,
        # so clicking the button after a recording handles the same audio twice.
        submit_btn.click(
            fn=process_audio,
            inputs=[audio_input, chatbot],
            outputs=[chatbot, audio_output, status_display]
        )
        clear_btn.click(
            fn=clear_conversation,
            outputs=[chatbot, audio_output, status_display]
        )
        audio_input.change(
            fn=process_audio,
            inputs=[audio_input, chatbot],
            outputs=[chatbot, audio_output, status_display]
        )
    return demo
# Launch application: build the UI, then serve it on all interfaces at port 7860.
if __name__ == "__main__":
    print("🌟 Creating interface...")
    demo = create_interface()
    print("πŸš€ Launching application...")
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
    )
    demo.launch(**launch_options)