Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -1,59 +1,49 @@
|
|
|
|
|
|
1 |
import torch
|
2 |
import torchaudio
|
3 |
-
from snac import SNAC
|
4 |
import gradio as gr
|
5 |
-
|
6 |
-
# choose your SNAC model + target sample rate
|
7 |
-
MODEL_NAME = "hubertsiuzdak/snac_24khz"
|
8 |
-
TARGET_SR = 24000
|
9 |
|
10 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
11 |
-
|
|
|
12 |
|
13 |
-
def
|
14 |
if audio_in is None:
|
15 |
return None
|
16 |
|
17 |
-
sr, data = audio_in # data: (T,) or (T,C)
|
18 |
|
19 |
-
#
|
20 |
if data.ndim == 2:
|
21 |
data = data.mean(axis=1)
|
22 |
|
23 |
-
# torchify
|
24 |
-
|
25 |
|
26 |
-
# resample to target SR
|
27 |
if sr != TARGET_SR:
|
28 |
-
|
29 |
-
|
30 |
-
# expand to [B,1,T]
|
31 |
-
audio = audio.unsqueeze(0).to(DEVICE)
|
32 |
|
|
|
|
|
33 |
with torch.inference_mode():
|
34 |
-
|
|
|
35 |
|
36 |
-
y =
|
37 |
return (TARGET_SR, y)
|
38 |
|
39 |
-
with gr.Blocks(title="SNAC
|
40 |
-
gr.Markdown("## 🎧 SNAC
|
41 |
-
|
42 |
with gr.Row():
|
43 |
with gr.Column():
|
44 |
-
audio_in = gr.Audio(
|
45 |
-
|
46 |
-
type="numpy",
|
47 |
-
label="Input audio"
|
48 |
-
)
|
49 |
-
btn = gr.Button("Encode + Decode")
|
50 |
with gr.Column():
|
51 |
-
audio_out = gr.Audio(
|
52 |
-
|
53 |
-
label="Reconstructed audio"
|
54 |
-
)
|
55 |
-
|
56 |
-
btn.click(reconstruct, inputs=audio_in, outputs=audio_out)
|
57 |
|
58 |
if __name__ == "__main__":
|
59 |
demo.launch()
|
|
|
1 |
+
# pip install gradio torch torchaudio snac
|
2 |
+
|
3 |
import torch
|
4 |
import torchaudio
|
|
|
5 |
import gradio as gr
|
6 |
+
from snac import SNAC
|
|
|
|
|
|
|
7 |
|
8 |
# Run on the GPU when one is present; everything below honours this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# using the 32 kHz model per your example
TARGET_SR = 32000
# Load the pretrained SNAC codec once at import time, frozen in eval mode.
MODEL = (
    SNAC.from_pretrained("hubertsiuzdak/snac_32khz")
    .eval()
    .to(DEVICE)
)
12 |
+
def encode_then_decode(audio_in):
    """Round-trip a Gradio audio clip through the SNAC codec.

    Args:
        audio_in: Gradio ``type="numpy"`` audio — a ``(sample_rate, data)``
            tuple where ``data`` is ``(T,)`` mono or ``(T, C)`` multi-channel,
            integer PCM or float samples — or ``None`` if nothing was recorded.

    Returns:
        ``(TARGET_SR, samples)`` with ``samples`` a mono float numpy array,
        or ``None`` when ``audio_in`` is ``None``.
    """
    if audio_in is None:
        return None

    sr, data = audio_in  # data: (T,) mono or (T, C) stereo

    # torchify first so the dtype is still the on-disk PCM type.
    x = torch.from_numpy(data)
    if x.dtype.is_floating_point:
        x = x.float()
    else:
        # Gradio delivers integer PCM (typically int16); SNAC expects floats
        # in [-1, 1], so scale by the dtype's full-scale positive value.
        x = x.float() / float(torch.iinfo(x.dtype).max)

    # Mono-ize AFTER the float conversion — averaging an integer array in
    # numpy silently promotes it to an unscaled float64 array.
    if x.ndim == 2:
        x = x.mean(dim=1)

    # [T] -> [1, T] for torchaudio, then resample to the model's rate.
    x = x.unsqueeze(0)
    if sr != TARGET_SR:
        x = torchaudio.functional.resample(x, orig_freq=sr, new_freq=TARGET_SR)

    # Expand to [B=1, C=1, T] as SNAC expects, then encode -> decode.
    x = x.unsqueeze(0).to(DEVICE)
    with torch.inference_mode():
        codes = MODEL.encode(x)
        y = MODEL.decode(codes)  # [1, 1, T]

    # Drop only the batch and channel dims — a bare .squeeze() would also
    # collapse a length-1 time axis.
    y = y.squeeze(0).squeeze(0).detach().cpu().numpy()
    return (TARGET_SR, y)
37 |
|
38 |
+
# --- Gradio UI -----------------------------------------------------------
# One row, two columns: input audio and trigger button on the left, the
# reconstructed audio on the right; the button wires into encode_then_decode.
with gr.Blocks(title="SNAC Encode→Decode (Simple)") as demo:
    gr.Markdown(
        "## 🎧 SNAC Encode → Decode (32 kHz)\n"
        "Resample → `encode()` → `decode()` — that’s it."
    )
    with gr.Row():
        with gr.Column():
            audio_in = gr.Audio(
                sources=["upload", "microphone"],
                type="numpy",
                label="Input audio",
            )
            run = gr.Button("Encode + Decode")
        with gr.Column():
            audio_out = gr.Audio(
                type="numpy",
                label="Reconstructed (32 kHz)",
            )
    run.click(encode_then_decode, inputs=audio_in, outputs=audio_out)

if __name__ == "__main__":
    demo.launch()
|