Maya-AI / app.py
Devakumar868's picture
Update app.py
c16d7ae verified
raw
history blame
3.81 kB
import os
import gradio as gr
import torch
import numpy as np
from transformers import pipeline
from pyannote.audio import Pipeline as PyannotePipeline
from dia.model import Dia
from dac.utils import load_model as load_dac_model
# Environment setup
HF_TOKEN = os.environ["HF_TOKEN"]
device_map = "auto"
print("Loading models...")
# 1. Load RVQ Codec
print("Loading RVQ Codec...")
rvq = load_dac_model(tag="latest", model_type="44khz")
rvq.eval()
if torch.cuda.is_available():
rvq = rvq.to("cuda")
# 2. Load VAD Pipeline
print("Loading VAD...")
vad_pipe = PyannotePipeline.from_pretrained(
"pyannote/voice-activity-detection",
use_auth_token=HF_TOKEN
)
# 3. Load Ultravox Pipeline
print("Loading Ultravox...")
ultravox_pipe = pipeline(
model="fixie-ai/ultravox-v0_4",
trust_remote_code=True,
device_map=device_map,
torch_dtype=torch.float16
)
# 4. Skip Audio Diffusion (causing UNet mismatch)
print("Skipping Audio Diffusion due to compatibility issues...")
diff_pipe = None
# 5. Load Dia TTS (correct method based on current API)
print("Loading Dia TTS...")
dia = Dia.from_pretrained("nari-labs/Dia-1.6B")
print("All models loaded successfully!")
def process_audio(audio):
try:
if audio is None:
return None, "No audio input provided"
sr, array = audio
# Ensure numpy array
if torch.is_tensor(array):
array = array.numpy()
# VAD processing
try:
vad_result = vad_pipe({"waveform": torch.tensor(array).unsqueeze(0), "sample_rate": sr})
except Exception as e:
print(f"VAD processing error: {e}")
# RVQ encode/decode
audio_tensor = torch.tensor(array).unsqueeze(0)
if torch.cuda.is_available():
audio_tensor = audio_tensor.to("cuda")
codes = rvq.encode(audio_tensor)
decoded = rvq.decode(codes).squeeze().cpu().numpy()
# Ultravox ASR + LLM
ultra_out = ultravox_pipe({"array": decoded, "sampling_rate": sr})
text = ultra_out.get("text", "I understand your audio input.")
# Skip diffusion processing due to compatibility issues
prosody_audio = decoded
# Dia TTS generation
tts_output = dia.generate(f"[emotion:neutral] {text}")
# Convert to numpy and normalize
if torch.is_tensor(tts_output):
tts_np = tts_output.squeeze().cpu().numpy()
else:
tts_np = np.array(tts_output)
# Normalize audio output
if len(tts_np) > 0:
tts_np = tts_np / np.max(np.abs(tts_np)) * 0.95
return (sr, tts_np), text
except Exception as e:
print(f"Error in process_audio: {e}")
return None, f"Processing error: {str(e)}"
# Gradio Interface
with gr.Blocks(title="Maya AI πŸ“ˆ") as demo:
gr.Markdown("# Maya-AI: Supernatural Conversational Agent")
gr.Markdown("Record audio to interact with the AI agent that understands emotions and responds naturally.")
with gr.Row():
with gr.Column():
audio_in = gr.Audio(
sources=["microphone"],
type="numpy",
label="Record Your Voice"
)
send_btn = gr.Button("Send", variant="primary")
with gr.Column():
audio_out = gr.Audio(label="AI Response")
text_out = gr.Textbox(
label="Generated Text",
lines=3,
placeholder="AI response will appear here..."
)
# Event handler
send_btn.click(
fn=process_audio,
inputs=audio_in,
outputs=[audio_out, text_out]
)
if __name__ == "__main__":
demo.launch()