# SNAC / app.py
# Last change: commit 9e72440 "Update app.py" by mrfakename (verified)
# pip install gradio torch torchaudio snac
import gradio as gr
import numpy as np
import torch
import torchaudio
from snac import SNAC
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Sample rate (Hz) of the pretrained checkpoint loaded below. Incoming audio
# is resampled to this rate before it is encoded.
TARGET_SR = 32000

# Load the pretrained 32 kHz SNAC codec once at import time, place it on
# DEVICE and switch it to inference (eval) mode.
MODEL = SNAC.from_pretrained("hubertsiuzdak/snac_32khz").to(DEVICE).eval()
def _to_float_mono(data):
    """Convert a Gradio numpy audio array to mono float32 samples in [-1, 1].

    The ``type="numpy"`` Audio component typically delivers integer PCM
    (e.g. int16), while SNAC expects floating-point audio in about [-1, 1].
    Integer samples are therefore rescaled by their dtype's full-scale range.
    Float input is assumed to already be in [-1, 1] and is passed through.

    Args:
        data: Array of shape (T,) for mono or (T, C) for multi-channel audio.

    Returns:
        1-D float32 array of length T.
    """
    data = np.asarray(data)
    if np.issubdtype(data.dtype, np.integer):
        info = np.iinfo(data.dtype)
        # Midpoint and half-range of the dtype: (0, 32768) for int16 and
        # (128, 128) for uint8, so both map onto [-1, 1).
        offset = (int(info.max) + int(info.min) + 1) / 2
        half_range = (int(info.max) - int(info.min) + 1) / 2
        data = (data.astype(np.float64) - offset) / half_range
    if data.ndim == 2:
        # Average the channels down to mono; layout is (T, C).
        data = data.mean(axis=1)
    return data.astype(np.float32)


def encode_then_decode(audio_in):
    """Round-trip audio through SNAC: resample, encode to codes, decode back.

    Args:
        audio_in: ``(sample_rate, data)`` tuple from a
            ``gr.Audio(type="numpy")`` component, or ``None`` when no audio
            was provided.

    Returns:
        ``(TARGET_SR, int16 array)`` holding the mono reconstruction, trimmed
        to the length of the resampled input, or ``None`` when there is no
        input (missing or zero-length).
    """
    if audio_in is None:
        return None
    sr, data = audio_in

    mono = _to_float_mono(data)
    if mono.size == 0:
        # Nothing to encode (e.g. an empty recording).
        return None

    x = torch.from_numpy(mono).unsqueeze(0)  # [1, T]
    if sr != TARGET_SR:
        x = torchaudio.functional.resample(x, orig_freq=sr, new_freq=TARGET_SR)
    num_samples = x.shape[-1]
    x = x.unsqueeze(0).to(DEVICE)  # [B=1, C=1, T]

    with torch.inference_mode():
        codes = MODEL.encode(x)
        y = MODEL.decode(codes)

    # The codec may pad its input to a multiple of its frame size, so drop
    # any trailing samples beyond the input length. reshape(-1) (rather
    # than squeeze) keeps a 1-D array even for a single-sample clip.
    y = y[..., :num_samples].reshape(-1).cpu().numpy()
    # Hand Gradio int16 PCM so the clip plays at its true level instead of
    # being peak-normalised (with a warning) as float output would be.
    y = (np.clip(y, -1.0, 1.0) * 32767).astype(np.int16)
    return (TARGET_SR, y)
# Minimal Gradio UI. The left column holds the input audio widget and the
# trigger button; the right column shows the SNAC reconstruction.
with gr.Blocks(title="SNAC Encode→Decode (Simple)") as demo:
    gr.Markdown(
        "## 🎧 SNAC Encode → Decode (32 kHz)\n"
        "Resample → `encode()` → `decode()` — that’s it."
    )
    with gr.Row():
        with gr.Column():
            audio_in = gr.Audio(
                sources=["upload", "microphone"],
                type="numpy",
                label="Input audio",
            )
            run = gr.Button("Encode + Decode")
        with gr.Column():
            audio_out = gr.Audio(
                type="numpy",
                label="Reconstructed (32 kHz)",
            )
    # Pressing the button sends the input clip through encode_then_decode
    # and displays the returned audio in the output player.
    run.click(
        fn=encode_then_decode,
        inputs=audio_in,
        outputs=audio_out,
    )

if __name__ == "__main__":
    demo.launch()