import os, torch, numpy as np, soundfile as sf
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import nemo.collections.asr as nemo_asr
from TTS.api import TTS
from sklearn.linear_model import LogisticRegression  # for emotion prediction
from datasets import load_dataset
# Configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAMPLE_RATE = 22050
SEED = 42
torch.manual_seed(SEED); np.random.seed(SEED)
# 1. ASR: Parakeet RNNT
asr = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
model_name="nvidia/parakeet-rnnt-1.1b"
).to(DEVICE); asr.eval()
# 2. SER: emotion classifier (wav2vec2 features intended; placeholder below)
ds = load_dataset("patrickvonplaten/emotion_speech", split="train[:10%]")  # small sample split
labels = ds["label"]
# Placeholder feature extraction: random 128-dim vectors stand in for real audio embeddings,
# so the classifier's predictions are arbitrary until real features are plugged in.
X = np.random.rand(len(labels), 128)
y = np.array(labels)
clf = LogisticRegression().fit(X, y)
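# --- Sketch (assumption, not used by the app): real clip-level features for the SER step. ---
# Mean-pooled wav2vec2 hidden states could replace the random matrix above; the backbone
# "facebook/wav2vec2-base" is an illustrative choice, not something the original code specifies.
def extract_w2v_features(waveform, sr=16000):
    from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
    fe = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
    w2v = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base").to(DEVICE).eval()
    inputs = fe(waveform, sampling_rate=sr, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        hidden = w2v(**inputs).last_hidden_state   # (1, frames, 768)
    return hidden.mean(dim=1).cpu().numpy()        # (1, 768) clip-level embedding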
# 3. NLP: Llama 3 in 4-bit (decoder-only model, so AutoModelForCausalLM; gated repo, needs HF access)
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
LLM_ID = "meta-llama/Meta-Llama-3-8B-Instruct"  # Llama 3 ships as 8B/70B; there is no "Llama-3-7b" repo
tokenizer = AutoTokenizer.from_pretrained(LLM_ID)
llm = AutoModelForCausalLM.from_pretrained(
    LLM_ID, quantization_config=bnb_config, device_map="auto"
)  # device_map="auto" already places the quantized weights; calling .to(DEVICE) on a 4-bit model errors
# 4. Emotion Prediction: map the SER classifier's output to an emotion tag
def predict_emotion(audio_path):
    # Placeholder: random features mirror the placeholder training above, so the prediction
    # is arbitrary; extract real embeddings from audio_path for meaningful output.
    return clf.predict(np.random.rand(1, 128))[0]
# 5. TTS: Dia 1.6B with emotion conditioning
tts = TTS("nari-labs/Dia-1.6B", progress_bar=False, gpu=torch.cuda.is_available())
def transcribe(audio):
    sr, data = audio  # gr.Audio(type="numpy") passes a (sample_rate, np.ndarray) tuple
    sf.write("in.wav", data, sr)
    return asr.transcribe(["in.wav"])[0].text
def generate_response(text, emo_tag):
    prompt = f"[emotion:{emo_tag}] {text}"
    inputs = tokenizer(prompt, return_tensors="pt").to(llm.device)
    gen = llm.generate(**inputs, max_new_tokens=100, do_sample=True, temperature=0.7)
    # Decode only the newly generated tokens so the reply does not echo the prompt
    return tokenizer.decode(gen[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
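# Optional sketch (assumption): if an Instruct-tuned checkpoint is loaded above, its chat
# template usually produces better-formed replies than the raw "[emotion:...]" prompt string.
def generate_response_chat(text, emo_tag):
    messages = [
        {"role": "system", "content": f"Reply empathetically; the user sounds {emo_tag}."},
        {"role": "user", "content": text},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(llm.device)
    gen = llm.generate(input_ids, max_new_tokens=100, do_sample=True, temperature=0.7)
    return tokenizer.decode(gen[0][input_ids.shape[1]:], skip_special_tokens=True)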
def synthesize(text, emo_tag):
    # emo_tag is not applied here yet; plain synthesis only
    wav = tts.tts(text=text, speaker_wav=None, style_wav=None)
    return np.asarray(wav, dtype=np.float32)  # Coqui returns a Python list; Gradio expects an array
def pipeline_fn(audio):
    user_text = transcribe(audio)  # also writes in.wav, reused by the SER step below
    emo = predict_emotion("in.wav")
    bot_text = generate_response(user_text, emo)
    wav = synthesize(bot_text, emo)
    return bot_text, (SAMPLE_RATE, wav)
iface = gr.Interface(
    fn=pipeline_fn,
    inputs=gr.Audio(source="microphone", type="numpy"),  # Gradio 3.x argument; on 4.x use sources=["microphone"]
    outputs=[gr.Textbox(), gr.Audio()],
    title="Emotion-Aware Conversational AI",
)
iface.launch(server_name="0.0.0.0", server_port=7860)
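# Run locally with `python app.py` and open http://localhost:7860; on Hugging Face Spaces
# the launch() call above serves the app on the same port.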