Spaces:
Paused
Paused
# pip install gradio torch torchaudio soundfile snac
import torch | |
import torchaudio | |
from snac import SNAC | |
import gradio as gr | |
# Run on the GPU when torch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Load the pretrained SNAC checkpoint a single time at import, switch it to
# eval mode, and move it to the selected device so every request reuses it.
MODEL = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
MODEL = MODEL.eval().to(DEVICE)
def reconstruct(audio_in):
    """Round-trip an audio clip through the SNAC codec at 24 kHz.

    Args:
        audio_in: ``(sample_rate, data)`` as produced by
            ``gr.Audio(type="numpy")``, or ``None`` when no audio was given.
            ``data`` is a NumPy array shaped ``(T,)`` or ``(T, C)`` holding
            either integer PCM samples or float samples.

    Returns:
        ``(24000, samples)`` where ``samples`` is a 1-D float array clamped
        to ``[-1.0, 1.0]``, or ``None`` when there is no audio to process
        (missing input or a zero-length clip).
    """
    target_sr = 24000

    if audio_in is None:
        return None

    sr, data = audio_in
    if data.size == 0:
        # A zero-length clip gives the model nothing to encode; treat it the
        # same way as "no input" instead of failing inside the network.
        return None

    wav = torch.from_numpy(data)

    # Bring the samples to float in [-1, 1]. Gradio's numpy audio is
    # documented as 16-bit integer PCM; feeding raw integer magnitudes to the
    # model would be thousands of times too loud and the final clamp would
    # turn the output into a square wave.
    if wav.dtype == torch.uint8:
        # Unsigned 8-bit PCM is centred on 128 rather than 0.
        wav = (wav.float() - 128.0) / 128.0
    elif not torch.is_floating_point(wav):
        full_scale = float(torch.iinfo(wav.dtype).max) + 1.0
        wav = wav.float() / full_scale
    else:
        wav = wav.float()

    # Mix down to mono. Any 2-D input is reduced over the channel axis,
    # including the single-channel (T, 1) layout, so the result is always (T,).
    if wav.ndim == 2:
        wav = wav.mean(dim=1)
    wav = wav.unsqueeze(0)  # [1, T]

    if sr != target_sr:
        wav = torchaudio.functional.resample(
            wav, orig_freq=int(sr), new_freq=target_sr
        )

    batch = wav.unsqueeze(0).to(DEVICE)  # [1, 1, T]

    with torch.inference_mode():
        # The forward pass returns the reconstructed audio first, followed by
        # the codes. Index the first element instead of unpacking a fixed
        # number of values, which raises when the tuple length differs.
        y_hat = MODEL(batch)[0]  # [1, 1, T]

    restored = y_hat[0, 0].detach().cpu()
    restored = torch.clamp(restored, -1.0, 1.0)  # safety clamp
    return (target_sr, restored.numpy())
# Two-column page: input audio and trigger button on the left,
# reconstructed audio on the right.
with gr.Blocks(title="SNAC Audio Reconstructor") as demo:
    gr.Markdown("## 🎵 SNAC Audio Reconstructor (24kHz)")
    gr.Markdown(
        "Upload or record audio. It’ll get resampled to 24kHz, "
        "mono-ized, then passed through SNAC for reconstruction."
    )

    with gr.Row():
        # Left column: where the user supplies audio and starts the run.
        with gr.Column():
            audio_in = gr.Audio(
                label="Input audio",
                type="numpy",
                sources=["upload", "microphone"],
            )
            btn = gr.Button("Reconstruct")

        # Right column: playback of the codec's output.
        with gr.Column():
            audio_out = gr.Audio(
                label="Reconstructed audio (24kHz)",
                type="numpy",
            )

    # Clicking the button sends the input clip through reconstruct() and
    # shows the result in the output player.
    btn.click(reconstruct, audio_in, audio_out)


# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()