import os
import gradio as gr
import torch
import torchaudio
import numpy as np
from transformers import pipeline
from diffusers import DiffusionPipeline
from pyannote.audio import Pipeline as PyannotePipeline
from dia.model import Dia
from dac.utils import load_model as load_dac_model

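# Processing chain for a single Gradio request:
#   mic audio -> VAD (pyannote) -> RVQ round-trip (DAC) -> Ultravox (speech -> LLM text)
#   -> audio-diffusion prosody pass -> Dia TTS -> synthesized audio + text back to the UI
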
# Environment token
HF_TOKEN = os.environ["HF_TOKEN"]
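# (pyannote/voice-activity-detection is a gated checkpoint: the token must belong to an
# account that has accepted its user conditions on the Hugging Face Hub)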

# Shard large models across 4× L4 GPUs
device_map = "auto"

# 1. RVQ codec (Descript Audio Codec, 44.1 kHz)
rvq_device = "cuda" if torch.cuda.is_available() else "cpu"
rvq = load_dac_model(tag="latest", model_type="44khz")
rvq.eval()
rvq = rvq.to(rvq_device)

# 2. Voice Activity Detection via Pyannote
vad_pipe = PyannotePipeline.from_pretrained(
    "pyannote/voice-activity-detection",
    use_auth_token=HF_TOKEN
)

# 3. Ultravox pipeline (speech → text + LLM)
ultravox_pipe = pipeline(
    model="fixie-ai/ultravox-v0_4",
    trust_remote_code=True,
    device_map=device_map,
    torch_dtype=torch.float16
)

# 4. Diffusion-based prosody model.
# transformers has no "audio-to-audio" pipeline task, and this checkpoint is a diffusers
# AudioDiffusionPipeline, so it is loaded through diffusers (in full precision; the model is small).
diff_pipe = DiffusionPipeline.from_pretrained(
    "teticio/audio-diffusion-instrumental-hiphop-256"
)
if torch.cuda.is_available():
    diff_pipe = diff_pipe.to("cuda")

# 5. Dia TTS. Dia is not a transformers model, so accelerate's empty-weights /
# checkpoint-dispatch loading does not apply; use the dia package's own loader
# (compute_dtype follows the package API) on a single device.
dia = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")

# Inference function: Gradio gives (sample_rate, waveform); return (audio, text) for the UI
def process_audio(audio):
    sr, array = audio
    if torch.is_tensor(array):
        array = array.numpy()
    # Gradio's numpy audio is usually int16; convert to mono float32 in [-1, 1]
    if np.issubdtype(array.dtype, np.integer):
        array = array.astype(np.float32) / np.iinfo(array.dtype).max
    if array.ndim > 1:
        array = array.mean(axis=1)

    # Resample to the codec's native rate (DAC asserts on its 44.1 kHz sample rate)
    if sr != rvq.sample_rate:
        array = torchaudio.functional.resample(torch.from_numpy(array), sr, rvq.sample_rate).numpy()
        sr = rvq.sample_rate

    # VAD: pyannote expects {"waveform": (channel, time) float tensor, "sample_rate": int};
    # the detected speech regions are informational only for now
    speech_regions = vad_pipe({"waveform": torch.from_numpy(array).unsqueeze(0), "sample_rate": sr})

    # RVQ encode/decode round-trip (DAC's encode returns (z, codes, latents, ...); decode takes z)
    x = rvq.preprocess(torch.from_numpy(array).reshape(1, 1, -1).to(rvq_device), sr)
    with torch.no_grad():
        z, codes, latents, _, _ = rvq.encode(x)
        decoded = rvq.decode(z).squeeze().cpu().numpy()

    # Ultravox ASR + LLM. Its remote-code pipeline expects 'audio', 'turns', 'sampling_rate',
    # and its Whisper encoder runs at 16 kHz, so resample the codec output first.
    speech_16k = torchaudio.functional.resample(torch.from_numpy(decoded), sr, 16000).numpy()
    turns = [{"role": "system", "content": "You are a friendly and helpful voice assistant."}]
    out = ultravox_pipe({"audio": speech_16k, "turns": turns, "sampling_rate": 16000},
                        max_new_tokens=128)
    text = out if isinstance(out, str) else out.get("text", "")

    # Diffusion prosody pass (illustrative: the regenerated audio is not fed into the TTS below)
    _, (_, diff_audios) = diff_pipe(raw_audio=decoded, return_dict=False)
    pros_audio = diff_audios[0]

    # Dia TTS synthesis. [emotion:...] is not part of Dia's documented tag set;
    # its README uses [S1]/[S2] speaker tags, so tag the text that way instead.
    tts = dia.generate(f"[S1] {text}")
    tts_np = tts.detach().cpu().numpy() if torch.is_tensor(tts) else np.asarray(tts, dtype=np.float32)
    tts_np = tts_np.squeeze()
    peak = np.max(np.abs(tts_np))
    if peak > 0:
        tts_np = tts_np / peak * 0.95  # peak-normalize to avoid clipping

    # Dia synthesizes at 44.1 kHz (the DAC codec rate), independent of the input rate
    return (44100, tts_np), text

# Gradio UI
with gr.Blocks(title="Maya AI 📈", theme=None) as demo:
    gr.Markdown("## Maya-AI: Supernatural Conversational Agent")
    # Gradio 4.x takes `sources=[...]`; older 3.x releases used `source="microphone"`
    audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Your Voice")
    send_btn = gr.Button("Send")
    audio_out = gr.Audio(label="AI's Response")
    text_out = gr.Textbox(label="Generated Text")
    send_btn.click(process_audio, inputs=audio_in, outputs=[audio_out, text_out])

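# Dia + Ultravox generation can take tens of seconds per request; Gradio's demo.queue()
# can be called before launch() to queue concurrent users instead of timing them out.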
if __name__ == "__main__":
    demo.launch()