Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -1,38 +1,66 @@
|
|
|
|
|
|
1 |
import torch
|
2 |
import torchaudio
|
3 |
from snac import SNAC
|
4 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=24000)
|
12 |
-
audio = resampler(audio)
|
13 |
|
14 |
-
|
15 |
-
if audio.size(0) > 1:
|
16 |
-
audio = torch.mean(audio, dim=0, keepdim=True)
|
17 |
|
18 |
-
# Confirm audio is in the shape [1, 1, T] where T is the sequence length
|
19 |
-
print("Audio size after processing:", audio.size(), audio.shape)
|
20 |
|
21 |
-
|
22 |
-
|
|
|
|
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
28 |
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
33 |
|
34 |
-
|
35 |
-
audio_hat = audio_hat.cpu().detach().numpy()
|
36 |
|
37 |
-
|
38 |
-
|
|
|
1 |
+
# pip install gradio torch torchaudio soundfile snac
|
2 |
+
|
3 |
import torch
|
4 |
import torchaudio
|
5 |
from snac import SNAC
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
# Run on the GPU when torch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Load the pretrained 24 kHz SNAC checkpoint once at import time so every
# request reuses it, then switch it to eval mode and move it to DEVICE.
MODEL = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
MODEL = MODEL.eval().to(DEVICE)
|
12 |
+
|
13 |
+
def reconstruct(audio_in):
    """Round-trip an audio clip through the SNAC codec.

    Args:
        audio_in: ``(sample_rate, data)`` as delivered by
            ``gr.Audio(type="numpy")``, or ``None`` when no audio was given.
            ``data`` is a NumPy array shaped ``(T,)`` or ``(T, C)`` holding
            either integer PCM or floating-point samples.

    Returns:
        ``(24000, samples)`` where ``samples`` is a 1-D NumPy array clamped
        to ``[-1.0, 1.0]``, or ``None`` when there is no audio to process
        (``audio_in`` is ``None`` or contains zero samples).
    """
    if audio_in is None:
        return None

    target_sr = 24000
    sr, data = audio_in  # data: (T,) or (T, C)

    x = torch.from_numpy(data)
    if x.numel() == 0:
        # Nothing was recorded/uploaded; do not feed an empty tensor to the model.
        return None

    # Bring samples to float in [-1, 1]. Integer PCM must be scaled by its
    # full-scale value; a bare .float() cast would keep magnitudes such as
    # +/-32768, which the final clamp would then flatten into noise.
    if x.is_floating_point():
        x = x.float()
    elif x.dtype == torch.uint8:
        # 8-bit PCM is unsigned and centred on 128.
        x = (x.float() - 128.0) / 128.0
    else:
        x = x.float() / (float(torch.iinfo(x.dtype).max) + 1.0)

    # Collapse the channel axis for every 2-D input, including (T, 1), so the
    # tensor is always 1-D before the batch/channel dimensions are added.
    if x.ndim == 2:
        x = x.mean(dim=1)

    x = x.unsqueeze(0)  # [1, T]

    if sr != target_sr:
        x = torchaudio.functional.resample(x, orig_freq=sr, new_freq=target_sr)

    x = x.unsqueeze(0).to(DEVICE)  # [1, 1, T]

    with torch.inference_mode():
        # Take the first returned value (the reconstruction) by index rather
        # than unpacking a fixed-length tuple, so the call does not depend on
        # how many auxiliary values the installed SNAC forward() returns.
        y_hat = MODEL(x)[0]  # expected [1, 1, T]

    y = y_hat.squeeze(0).squeeze(0).detach().cpu()
    y = torch.clamp(y, -1.0, 1.0)  # safety clamp

    return (target_sr, y.numpy())
|
|
|
|
|
41 |
|
|
|
|
|
42 |
|
43 |
+
# Gradio front end: the input widget and trigger button sit in the left
# column, the reconstructed clip is shown in the right column.
with gr.Blocks(title="SNAC Audio Reconstructor") as demo:
    gr.Markdown("## 🎵 SNAC Audio Reconstructor (24kHz)")
    gr.Markdown(
        "Upload or record audio. It’ll get resampled to 24kHz, mono-ized, "
        "then passed through SNAC for reconstruction."
    )

    with gr.Row():
        # Left column: audio source and the button that starts processing.
        with gr.Column():
            audio_in = gr.Audio(
                label="Input audio",
                type="numpy",
                sources=["upload", "microphone"],
            )
            btn = gr.Button("Reconstruct")

        # Right column: player for the codec output.
        with gr.Column():
            audio_out = gr.Audio(
                label="Reconstructed audio (24kHz)",
                type="numpy",
            )

    # Clicking the button passes the input audio to reconstruct() and shows
    # its return value in the output player.
    btn.click(
        inputs=audio_in,
        outputs=audio_out,
        fn=reconstruct,
    )


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|