import os
import gc
import tempfile

import numpy as np
import soundfile as sf
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
import nemo.collections.asr as nemo_asr
from TTS.api import TTS
from sklearn.linear_model import LogisticRegression

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
SAMPLE_RATE = 22050  # output rate of the LJSpeech Tacotron2-DDC TTS model
TEMPERATURE = 0.7

torch.manual_seed(SEED)
np.random.seed(SEED)
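
# Pipeline overview: microphone audio -> ASR (Parakeet, with Whisper as a
# fallback) -> demo emotion classifier -> DialoGPT reply -> TTS playback,
# all wired together through a Gradio interface.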
print(f"π System Info:") |
|
print(f"Device: {DEVICE}") |
|
print(f"NumPy: {np.__version__}") |
|
print(f"PyTorch: {torch.__version__}") |
|
if torch.cuda.is_available(): |
|
print(f"CUDA: {torch.version.cuda}") |
|
|
|


class ConversationalAI:
    def __init__(self):
        print("🚀 Initializing Conversational AI...")
        self.setup_models()
        print("✅ All models loaded successfully!")

    def setup_models(self):
        print("📢 Loading ASR model...")
        try:
            self.asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
                "nvidia/parakeet-rnnt-1.1b"
            ).to(DEVICE).eval()
            print("✅ Parakeet ASR loaded")
        except Exception as e:
            print(f"⚠️ Parakeet failed: {e}")
            print("🔄 Loading Whisper fallback...")
            self.asr_pipeline = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-base.en",
                device=0 if DEVICE == "cuda" else -1
            )
            print("✅ Whisper ASR loaded")
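
        # The emotion recognizer below is a demo stand-in: a LogisticRegression
        # fit on random 128-dim vectors. A real system would extract acoustic
        # features (e.g. MFCCs or wav2vec embeddings) from the recording and
        # train on a labeled speech-emotion corpus.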
print("π Setting up emotion recognition...") |
|
X_demo = np.random.rand(100, 128) |
|
y_demo = np.random.randint(0, 5, 100) |
|
self.ser_clf = LogisticRegression().fit(X_demo, y_demo) |
|
self.emotion_labels = ["neutral", "happy", "sad", "angry", "surprised"] |
|
print("β
SER model ready") |
|
|
|
|
|
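
        # 4-bit NF4 quantization roughly quarters the LLM's weight memory;
        # double quantization additionally compresses the quantization
        # constants at a small dequantization cost.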
print("π§ Loading LLM...") |
|
bnb_cfg = BitsAndBytesConfig( |
|
load_in_4bit=True, |
|
bnb_4bit_compute_dtype=torch.float16, |
|
bnb_4bit_use_double_quant=True, |
|
bnb_4bit_quant_type="nf4" |
|
) |
|
|
|
model_name = "microsoft/DialoGPT-medium" |
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
self.tokenizer.pad_token = self.tokenizer.eos_token |
|
|
|
self.llm_model = AutoModelForCausalLM.from_pretrained( |
|
model_name, |
|
quantization_config=bnb_cfg, |
|
device_map="auto", |
|
torch_dtype=torch.float16, |
|
low_cpu_mem_usage=True |
|
) |
|
print("β
LLM loaded") |
|
|
|
|
|
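
        # Tacotron2-DDC is a single-speaker model trained on LJSpeech; it
        # produces 22.05 kHz audio, matching the SAMPLE_RATE used in
        # synthesize() below.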
print("π£οΈ Loading TTS...") |
|
try: |
|
self.tts = TTS("tts_models/en/ljspeech/tacotron2-DDC").to(DEVICE) |
|
print("β
TTS loaded") |
|
except Exception as e: |
|
print(f"β οΈ TTS error: {e}") |
|
self.tts = None |
|
|
|
|
|
if DEVICE == "cuda": |
|
torch.cuda.empty_cache() |
|
gc.collect() |
|
|
|
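
    # Gradio's type="numpy" audio arrives as a (sample_rate, samples) tuple;
    # both ASR paths below consume that format.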
    def transcribe(self, audio):
        """Convert speech to text"""
        try:
            if hasattr(self, 'asr_model'):
                # NeMo's transcribe() takes file paths, so write a temp WAV first
                temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
                temp_file.close()
                sf.write(temp_file.name, audio[1], audio[0])
                transcription = self.asr_model.transcribe([temp_file.name])[0]
                os.unlink(temp_file.name)
                return transcription.text if hasattr(transcription, 'text') else str(transcription)
            else:
                # The Whisper pipeline expects float32 samples; Gradio records int16
                raw = audio[1].astype(np.float32)
                if audio[1].dtype == np.int16:
                    raw /= 32768.0
                return self.asr_pipeline({"sampling_rate": audio[0], "raw": raw})["text"]
        except Exception as e:
            print(f"ASR Error: {e}")
            return "Sorry, I couldn't understand the audio."

    def predict_emotion(self):
        """Predict emotion (simplified demo: classifies a random feature
        vector rather than features of the actual audio)"""
        emotion_idx = self.ser_clf.predict(np.random.rand(1, 128))[0]
        return self.emotion_labels[emotion_idx]

    def generate_response(self, text, emotion):
        """Generate a conversational response conditioned on the detected emotion"""
        try:
            prompt = f"Human: {text}\nAssistant (feeling {emotion}):"
            inputs = self.tokenizer.encode(
                prompt, return_tensors="pt", max_length=512, truncation=True
            ).to(self.llm_model.device)

            with torch.no_grad():
                outputs = self.llm_model.generate(
                    inputs,
                    max_new_tokens=100,
                    temperature=TEMPERATURE,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    no_repeat_ngram_size=2,
                    top_p=0.9
                )

            # Decode only the newly generated tokens and cut at the next "Human:" turn
            response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
            response = response.split("Human:")[0].strip()
            return response if response else "I understand. Please tell me more."
        except Exception as e:
            print(f"LLM Error: {e}")
            return "I'm having trouble processing that. Could you please rephrase?"

    def synthesize(self, text):
        """Convert text to speech"""
        try:
            if self.tts:
                wav = self.tts.tts(text=text)
                if isinstance(wav, list):
                    wav = np.array(wav, dtype=np.float32)
                # Peak-normalize, then convert to 16-bit PCM for Gradio playback
                peak = np.max(np.abs(wav))
                if peak > 0:
                    wav = wav / peak
                return (SAMPLE_RATE, (wav * 32767).astype(np.int16))
            else:
                # No TTS model available: return one second of silence
                return (SAMPLE_RATE, np.zeros(SAMPLE_RATE, dtype=np.int16))
        except Exception as e:
            print(f"TTS Error: {e}")
            return (SAMPLE_RATE, np.zeros(SAMPLE_RATE, dtype=np.int16))
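
    # One full conversational turn. Each stage degrades gracefully (fallback
    # text or silence), so a single bad recording does not crash the UI.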
    def process_conversation(self, audio_input, chat_history):
        """Main pipeline: speech -> text -> emotion -> LLM -> speech"""
        if audio_input is None:
            return chat_history, None, ""

        try:
            user_text = self.transcribe(audio_input)
            if not user_text.strip():
                return chat_history, None, "No speech detected."

            emotion = self.predict_emotion()
            ai_response = self.generate_response(user_text, emotion)
            audio_response = self.synthesize(ai_response)
            chat_history.append([user_text, ai_response])

            if DEVICE == "cuda":
                torch.cuda.empty_cache()
                gc.collect()

            return chat_history, audio_response, f"You said: {user_text} (detected emotion: {emotion})"
        except Exception as e:
            error_msg = f"Error processing conversation: {e}"
            print(error_msg)
            return chat_history, None, error_msg
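

# Instantiate once at module load so every Gradio callback reuses the same
# loaded models rather than reloading them per request.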
print("π Starting Conversational AI...") |
|
ai_system = ConversationalAI() |
|
|
|
|
|
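

# The UI pairs a microphone input and a chat history with an autoplaying
# audio reply; the "Process Speech" button drives one conversational turn.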
def create_interface():
    with gr.Blocks(
        title="Emotion-Aware Conversational AI",
        theme=gr.themes.Soft()
    ) as demo:
        gr.HTML("""
        <div style="text-align: center; margin-bottom: 2rem;">
            <h1>🤖 Emotion-Aware Conversational AI</h1>
            <p>Speak naturally and get intelligent responses with emotion recognition</p>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Conversation History",
                    height=400,
                    show_copy_button=True
                )

                audio_input = gr.Audio(
                    label="🎤 Speak to AI",
                    sources=["microphone"],
                    type="numpy",
                    format="wav"
                )

                with gr.Row():
                    submit_btn = gr.Button("💬 Process Speech", variant="primary", scale=2)
                    clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary", scale=1)

            with gr.Column(scale=1):
                audio_output = gr.Audio(
                    label="🔊 AI Response",
                    type="numpy",
                    autoplay=True
                )

                status_display = gr.Textbox(
                    label="📊 Status",
                    lines=3,
                    interactive=False
                )

                gr.HTML(f"""
                <div style="padding: 1rem; background: #f0f9ff; border-radius: 0.5rem;">
                    <h3>🔧 System Info</h3>
                    <p><strong>Device:</strong> {DEVICE.upper()}</p>
                    <p><strong>PyTorch:</strong> {torch.__version__}</p>
                    <p><strong>Models:</strong> Parakeet + DialoGPT + TTS</p>
                    <p><strong>Features:</strong> Emotion Recognition</p>
                </div>
                """)
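
        # Event wiring: a thin wrapper keeps the Gradio callback signature
        # separate from the model pipeline.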
        def process_audio(audio, history):
            return ai_system.process_conversation(audio, history)

        def clear_conversation():
            if DEVICE == "cuda":
                torch.cuda.empty_cache()
                gc.collect()
            return [], None, "Conversation cleared."

        # The button is the single trigger; also binding audio_input.change
        # to process_audio would run the pipeline twice per recording.
        submit_btn.click(
            fn=process_audio,
            inputs=[audio_input, chatbot],
            outputs=[chatbot, audio_output, status_display]
        )

        clear_btn.click(
            fn=clear_conversation,
            outputs=[chatbot, audio_output, status_display]
        )

    return demo
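

# share=True creates a temporary public Gradio link in addition to the local
# server on 0.0.0.0:7860.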
if __name__ == "__main__":
    print("📋 Creating interface...")
    demo = create_interface()

    print("🚀 Launching application...")
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True
    )