# pip install gradio torch torchaudio soundfile snac
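"""Gradio demo: round-trip audio through the SNAC neural codec (24 kHz).

Input audio is normalized to float32, converted to mono, resampled to 24 kHz,
then encoded and decoded with SNAC so the reconstruction can be compared to
the original.
"""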

import numpy as np
import torch
import torchaudio
from snac import SNAC
import gradio as gr

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# load SNAC once
MODEL = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(DEVICE)
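# SNAC compresses audio into hierarchical discrete codes at several temporal
# resolutions; MODEL.encode() / MODEL.decode() below perform the round trip.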

def reconstruct(audio_in):
    """
    audio_in is (sample_rate:int, data:np.ndarray) from gr.Audio(type="numpy")
    returns (24000, np.ndarray)
    """
    if audio_in is None:
        return None

    sr, data = audio_in  # data: (T,) or (T, C)

    # Gradio's type="numpy" usually delivers int16 samples; scale to float32 in [-1, 1]
    if np.issubdtype(data.dtype, np.integer):
        data = data.astype(np.float32) / np.iinfo(data.dtype).max
    else:
        data = data.astype(np.float32)

    # convert to mono if stereo
    if data.ndim == 2 and data.shape[1] > 1:
        data = data.mean(axis=1)

    x = torch.from_numpy(data).unsqueeze(0)  # [1, T]

    if sr != 24000:
        x = torchaudio.functional.resample(x, orig_freq=sr, new_freq=24000)

    x = x.unsqueeze(0).to(DEVICE)  # [1, 1, T]

    with torch.inference_mode():
        # encode to SNAC's discrete codes, then decode back to a waveform
        codes = MODEL.encode(x)
        y_hat = MODEL.decode(codes)  # [1, 1, T']

    y = y_hat.squeeze(0).squeeze(0).detach().cpu()
    y = torch.clamp(y, -1.0, 1.0)  # safety clamp

    return (24000, y.numpy())


with gr.Blocks(title="SNAC Audio Reconstructor") as demo:
    gr.Markdown("## 🎵 SNAC Audio Reconstructor (24kHz)")
    gr.Markdown("Upload or record audio. It’ll get resampled to 24kHz, "
                "mono-ized, then passed through SNAC for reconstruction.")

    with gr.Row():
        with gr.Column():
            audio_in = gr.Audio(
                sources=["upload", "microphone"],
                type="numpy",
                label="Input audio"
            )
            btn = gr.Button("Reconstruct")

        with gr.Column():
            audio_out = gr.Audio(
                type="numpy",
                label="Reconstructed audio (24kHz)"
            )

    btn.click(fn=reconstruct, inputs=audio_in, outputs=audio_out)

if __name__ == "__main__":
    demo.launch()
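
# Standalone usage sketch (no Gradio UI); file names are illustrative only:
#
#   wav, sr = torchaudio.load("in.wav")                        # [C, T], float32
#   sr_out, recon = reconstruct((sr, wav.mean(dim=0).numpy()))
#   torchaudio.save("recon.wav", torch.from_numpy(recon).unsqueeze(0), sr_out)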