# SNAC / app.py — Hugging Face Space by mrfakename (commit 34ffbd5)
# pip install gradio torch torchaudio soundfile snac numpy
import numpy as np
import torch
import torchaudio

import gradio as gr
from snac import SNAC
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# load SNAC once
MODEL = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(DEVICE)
def reconstruct(audio_in):
"""
audio_in is (sample_rate:int, data:np.ndarray) from gr.Audio(type="numpy")
returns (24000, np.ndarray)
"""
if audio_in is None:
return None
sr, data = audio_in # data: (T,) or (T, C)
# convert to mono if stereo
if data.ndim == 2 and data.shape[1] > 1:
data = data.mean(axis=1)
x = torch.from_numpy(data).float().unsqueeze(0) # [1, T]
if sr != 24000:
x = torchaudio.functional.resample(x, orig_freq=sr, new_freq=24000)
x = x.unsqueeze(0).to(DEVICE) # [1, 1, T]
with torch.inference_mode():
y_hat, _, _, _, _ = MODEL(x) # [1, 1, T]
y = y_hat.squeeze(0).squeeze(0).detach().cpu()
y = torch.clamp(y, -1.0, 1.0) # safety clamp
return (24000, y.numpy())
with gr.Blocks(title="SNAC Audio Reconstructor") as demo:
gr.Markdown("## 🎵 SNAC Audio Reconstructor (24kHz)")
gr.Markdown("Upload or record audio. It’ll get resampled to 24kHz, "
"mono-ized, then passed through SNAC for reconstruction.")
with gr.Row():
with gr.Column():
audio_in = gr.Audio(
sources=["upload", "microphone"],
type="numpy",
label="Input audio"
)
btn = gr.Button("Reconstruct")
with gr.Column():
audio_out = gr.Audio(
type="numpy",
label="Reconstructed audio (24kHz)"
)
btn.click(fn=reconstruct, inputs=audio_in, outputs=audio_out)
if __name__ == "__main__":
demo.launch()