Spaces:
Paused
Paused
# pip install gradio torch torchaudio snac | |
import torch | |
import torchaudio | |
import gradio as gr | |
from snac import SNAC | |
# Choose the compute device once at import time: CUDA when PyTorch can see a
# GPU, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Sample rate (Hz) used by the app: input audio is resampled to this rate
# before encoding, and the reconstruction is reported at this rate. It must
# match the checkpoint loaded below (the 32 kHz SNAC model).
TARGET_SR = 32000

# Pretrained checkpoint id handed to SNAC.from_pretrained.
_MODEL_ID = "hubertsiuzdak/snac_32khz"

# Load the codec a single time so every request reuses the same weights.
# eval() and to() both act in place on the module, so calling them as
# statements is equivalent to chaining them.
MODEL = SNAC.from_pretrained(_MODEL_ID)
MODEL.eval()
MODEL.to(DEVICE)
def _pcm_to_float(data):
    """Return ``data`` as a float32 array, scaling integer PCM into [-1, 1).

    Signed integer samples are divided by ``2 ** (bits - 1)`` (32768 for
    int16). Unsigned samples (e.g. 8-bit WAV) are first re-centred on that
    midpoint. Floating-point input is assumed to already lie in [-1, 1] and
    is only cast. ``astype`` always returns a fresh, writable array.
    """
    half_range = float(2 ** (8 * data.dtype.itemsize - 1))
    if data.dtype.kind == "i":
        return (data / half_range).astype("float32")
    if data.dtype.kind == "u":
        return ((data - half_range) / half_range).astype("float32")
    return data.astype("float32")


def encode_then_decode(audio_in):
    """Round-trip an audio clip through the SNAC codec (encode, then decode).

    Args:
        audio_in: ``(sample_rate, samples)`` tuple as delivered by a
            ``gr.Audio(type="numpy")`` component, or ``None`` when the
            component is empty. ``samples`` is ``(T,)`` for mono or
            ``(T, C)`` for multi-channel audio, as integer PCM or float.

    Returns:
        ``(TARGET_SR, reconstruction)``, where ``reconstruction`` is a 1-D
        float32 NumPy array at ``TARGET_SR`` with the same number of samples
        as the resampled input. Returns ``None`` when there is no input or
        the clip is empty.
    """
    if audio_in is None:
        return None
    sr, data = audio_in
    if data.size == 0:
        # A zero-length recording cannot be run through the codec; treat it
        # the same as "no input".
        return None

    # Scale to float in [-1, 1] BEFORE downmixing. Averaging integer PCM
    # first would yield float64 values in the +-32768 range, which the
    # float-domain model would then receive un-normalized.
    x = torch.from_numpy(_pcm_to_float(data))
    if x.ndim == 2:
        x = x.mean(dim=1)  # (T, C) -> (T,)
    x = x.unsqueeze(0)  # [1, T]

    if sr != TARGET_SR:
        x = torchaudio.functional.resample(x, orig_freq=sr, new_freq=TARGET_SR)
    num_samples = x.shape[-1]

    x = x.unsqueeze(0).to(DEVICE)  # [1, 1, T]
    with torch.inference_mode():
        codes = MODEL.encode(x)
        y = MODEL.decode(codes)

    # SNAC's encode() right-pads the signal to a whole number of codec frames,
    # so decode() can return extra trailing samples; trim back to the input
    # length. reshape(-1) keeps the result 1-D even for a 1-sample clip.
    y = y.reshape(-1)[:num_samples].cpu().numpy()
    return (TARGET_SR, y)
# Two-column UI: the input clip and trigger button on the left, the
# codec reconstruction on the right.
with gr.Blocks(title="SNAC Encode→Decode (Simple)") as demo:
    gr.Markdown(
        "## 🎧 SNAC Encode → Decode (32 kHz)\n"
        "Resample → `encode()` → `decode()` — that’s it."
    )
    with gr.Row():
        with gr.Column():
            audio_in = gr.Audio(
                sources=["upload", "microphone"],
                type="numpy",
                label="Input audio",
            )
            run = gr.Button("Encode + Decode")
        with gr.Column():
            audio_out = gr.Audio(
                type="numpy",
                label="Reconstructed (32 kHz)",
            )
    # Event listeners are registered while the Blocks context is still open.
    run.click(
        encode_then_decode,
        inputs=audio_in,
        outputs=audio_out,
    )

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()