Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -3,50 +3,51 @@ import torchaudio
|
|
3 |
from snac import SNAC
|
4 |
import gradio as gr
|
5 |
|
|
|
|
|
|
|
|
|
6 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
7 |
-
|
8 |
|
9 |
def reconstruct(audio_in):
|
10 |
if audio_in is None:
|
11 |
return None
|
12 |
-
sr, data = audio_in # (sr, np.ndarray)
|
13 |
-
|
14 |
-
# to tensor [channels, T]
|
15 |
-
audio = torch.from_numpy(data.T).float()
|
16 |
|
17 |
-
#
|
18 |
-
if sr != 24000:
|
19 |
-
resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=24000)
|
20 |
-
audio = resampler(audio)
|
21 |
|
22 |
-
# stereo → mono
|
23 |
-
if
|
24 |
-
|
25 |
|
26 |
-
#
|
27 |
-
audio =
|
28 |
|
|
|
29 |
with torch.inference_mode():
|
30 |
-
|
31 |
-
audio_hat = out[0] if isinstance(out,(list,tuple)) else out
|
32 |
|
33 |
y = audio_hat.squeeze().cpu().numpy()
|
34 |
-
return (
|
35 |
|
36 |
-
with gr.Blocks() as demo:
|
37 |
-
gr.Markdown("## SNAC Audio Reconstructor (
|
38 |
|
39 |
with gr.Row():
|
40 |
with gr.Column():
|
41 |
-
audio_in = gr.Audio(
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
45 |
with gr.Column():
|
46 |
-
audio_out = gr.Audio(
|
47 |
-
|
|
|
|
|
48 |
|
49 |
-
btn.click(
|
50 |
|
51 |
if __name__ == "__main__":
|
52 |
demo.launch()
|
|
|
3 |
from snac import SNAC
import gradio as gr

# pick the right SNAC model for your audio sample rate
# (this checkpoint is the 24 kHz variant, so SR below must stay 24000)
MODEL_NAME = "hubertsiuzdak/snac_24khz"
SR = 24000

# Run on GPU when available; SNAC inference works on CPU too, just slower.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): from_pretrained downloads the checkpoint from the HF Hub at
# import time — a module-level network side effect. eval() disables dropout etc.
model = SNAC.from_pretrained(MODEL_NAME).eval().to(DEVICE)
|
12 |
|
13 |
def reconstruct(audio_in):
    """Round-trip audio through the SNAC codec (encode + decode).

    Parameters
    ----------
    audio_in : tuple[int, np.ndarray] | None
        Gradio "numpy" audio payload: (sample_rate, data), where data is
        (T,) mono or (T, C) multi-channel. None when no audio was provided.

    Returns
    -------
    tuple[int, np.ndarray] | None
        (SR, reconstructed_waveform) at the model's 24 kHz rate, or None
        when the input was None.
    """
    if audio_in is None:
        return None
    sr, data = audio_in  # data: (T,) or (T,C)

    # convert stereo → mono
    if data.ndim == 2:
        data = data.mean(axis=1)

    # turn into torch [1,1,T]
    # NOTE(review): gradio typically delivers int16 samples; .float() keeps the
    # raw ±32768 scale rather than normalizing to ±1 — confirm SNAC tolerates this.
    audio = torch.from_numpy(data).float().unsqueeze(0).unsqueeze(0).to(DEVICE)

    # The 24 kHz SNAC model expects 24 kHz input, and we report SR on the way
    # out — resample whenever the source rate differs (this step existed in the
    # previous revision and was lost; without it non-24 kHz input plays back
    # pitch-shifted).
    if sr != SR:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=SR).to(DEVICE)
        audio = resampler(audio)

    # run through SNAC
    with torch.inference_mode():
        audio_hat, codes = model(audio)

    y = audio_hat.squeeze().cpu().numpy()
    return (SR, y)
|
32 |
|
33 |
+
# Minimal two-column UI: upload/record on the left, codec output on the right.
with gr.Blocks(title="SNAC Round-Trip Demo") as demo:
    gr.Markdown("## 🎧 SNAC Audio Reconstructor (minimal!)")

    with gr.Row():
        with gr.Column():
            # type="numpy" makes gradio pass (sample_rate, np.ndarray) to the
            # callback — the shape reconstruct() expects.
            audio_in = gr.Audio(
                sources=["upload", "microphone"],
                type="numpy",
                label="Input audio"
            )
            btn = gr.Button("Encode + Decode")
        with gr.Column():
            audio_out = gr.Audio(
                type="numpy",
                label="Reconstructed audio"
            )

    # Wire the button: one input component, one output component.
    btn.click(reconstruct, inputs=audio_in, outputs=audio_out)

if __name__ == "__main__":
    demo.launch()
|