import numpy as np
import soundfile as sf
import torch
import gradio as gr
import nemo.collections.asr as nemo_asr
from datasets import load_dataset
from sklearn.linear_model import LogisticRegression
# Llama is decoder-only, so AutoModelForCausalLM (not AutoModelForSeq2SeqLM).
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from TTS.api import TTS
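
# Pipeline overview: microphone audio -> NeMo ASR transcription -> emotion
# classification -> emotion-tagged LLM prompt -> Coqui TTS reply audio.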

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAMPLE_RATE = 22050  # output rate of the LJSpeech TTS model loaded below
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Speech-to-text: NVIDIA Parakeet RNNT (1.1B parameters) via NeMo.
asr = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
    model_name="nvidia/parakeet-rnnt-1.1b"
).to(DEVICE)
asr.eval()

# Emotion classifier trained on a small slice of an emotional-speech dataset.
ds = load_dataset("patrickvonplaten/emotion_speech", split="train[:10%]")
features = ds["audio"]  # decoded audio; only its length is used below
labels = ds["label"]

# Placeholder features: random vectors stand in for real audio embeddings,
# so this classifier is a demo stub, not a trained emotion recognizer.
X = np.random.rand(len(features), 128)
y = np.array(labels)
clf = LogisticRegression(max_iter=1000).fit(X, y)
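
# Optional sanity check, a sketch beyond the original demo: hold out 20% of
# the data; with random placeholder features, accuracy should sit near chance.
from sklearn.model_selection import train_test_split

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=SEED)
probe = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
print(f"held-out accuracy: {probe.score(X_te, y_te):.3f}")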

# Response generator: 4-bit quantized Llama 3. The original "Llama-3-7b" id
# does not exist (Llama 3 ships as 8B/70B), so the 8B Instruct checkpoint
# is used here.
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
LLM_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(LLM_ID)
llm = AutoModelForCausalLM.from_pretrained(
    LLM_ID, quantization_config=bnb_config, device_map="auto"
)  # device_map="auto" already places the weights; .to(DEVICE) would raise

def predict_emotion(audio_path):
    # Placeholder: random features stand in for embeddings of `audio_path`.
    feats = np.random.rand(1, 128)
    return int(clf.predict(feats)[0])
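
# A real system would embed the actual audio. A minimal sketch using mean
# MFCCs (assumes librosa is installed; `extract_mfcc` is a hypothetical
# helper, not wired into the demo):
def extract_mfcc(audio_path, n_mfcc=128):
    import librosa  # local import so the demo runs without librosa
    wav, sr = librosa.load(audio_path, sr=None)
    mfcc = librosa.feature.mfcc(y=wav, sr=sr, n_mfcc=n_mfcc)
    return mfcc.mean(axis=1).reshape(1, -1)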

# Text-to-speech. "nari-labs/Dia-1.6B" is not loadable through Coqui's
# TTS.api (Dia ships with its own library), so a stock Coqui LJSpeech
# model that outputs 22.05 kHz audio is used instead.
tts = TTS("tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False,
          gpu=torch.cuda.is_available())

def transcribe(audio):
    sr, data = audio  # Gradio's numpy audio arrives as (sample_rate, array)
    sf.write("in.wav", data, sr)
    # Recent NeMo releases return Hypothesis objects carrying a .text field.
    return asr.transcribe(["in.wav"])[0].text

def generate_response(text, emo_tag):
    prompt = f"[emotion:{emo_tag}] {text}"
    inputs = tokenizer(prompt, return_tensors="pt").to(llm.device)
    gen = llm.generate(**inputs, max_new_tokens=100, do_sample=True, temperature=0.7)
    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = gen[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
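
# Instruct checkpoints generally behave better through their chat template.
# A sketch of the same call via apply_chat_template (hypothetical variant,
# not wired into the pipeline below):
def generate_response_chat(text, emo_tag):
    messages = [{"role": "user", "content": f"[emotion:{emo_tag}] {text}"}]
    ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(llm.device)
    gen = llm.generate(ids, max_new_tokens=100, do_sample=True, temperature=0.7)
    return tokenizer.decode(gen[0][ids.shape[1]:], skip_special_tokens=True)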

def synthesize(text, emo_tag):
    # Single-speaker LJSpeech model: emo_tag is accepted but unused for now.
    return np.array(tts.tts(text=text), dtype=np.float32)

def pipeline_fn(audio):
    user_text = transcribe(audio)
    emo = predict_emotion("in.wav")
    bot_text = generate_response(user_text, emo)
    wav = synthesize(bot_text, emo)
    return bot_text, (SAMPLE_RATE, wav)

iface = gr.Interface(
    fn=pipeline_fn,
    inputs=gr.Audio(sources=["microphone"], type="numpy"),  # Gradio 4.x: `sources`, not `source`
    outputs=[gr.Textbox(label="Response"), gr.Audio(label="Spoken reply")],
    title="Emotion-Aware Conversational AI",
)
iface.launch(server_name="0.0.0.0", server_port=7860)