mrfakename commited on
Commit
34ffbd5
·
verified ·
1 Parent(s): 17fb016

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -26
app.py CHANGED
@@ -1,38 +1,66 @@
 
 
1
  import torch
2
  import torchaudio
3
  from snac import SNAC
4
- import soundfile as sf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- filename = "/content/en_sample.wav"
7
- audio, sr = torchaudio.load(filename)
8
 
9
- # Resample to 24kHz if necessary
10
- if sr != 24000:
11
- resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=24000)
12
- audio = resampler(audio)
13
 
14
- # Convert to mono by averaging the channels if the audio is stereo
15
- if audio.size(0) > 1:
16
- audio = torch.mean(audio, dim=0, keepdim=True)
17
 
18
- # Confirm audio is in the shape [1, 1, T] where T is the sequence length
19
- print("Audio size after processing:", audio.size(), audio.shape)
20
 
21
- # Load the SNAC model
22
- model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
 
 
23
 
24
- # Move to CUDA if available
25
- if torch.cuda.is_available():
26
- model = model.cuda()
27
- audio = audio.cuda()
 
 
 
 
28
 
29
- audio = torch.unsqueeze(audio, 0)
30
- # Encode and decode the audio with SNAC
31
- with torch.inference_mode():
32
- audio_hat, _, codes, _, _ = model(audio)
 
33
 
34
- # Move the tensor back to CPU for saving and convert back to numpy
35
- audio_hat = audio_hat.cpu().detach().numpy()
36
 
37
- # Save the reconstructed audio file
38
- sf.write('reconstructed_audio.wav', audio_hat.squeeze(), 24000) # Use .squeeze() to remove single-dimensional entries
 
1
+ # pip install gradio torch torchaudio soundfile snac
2
+
3
  import torch
4
  import torchaudio
5
  from snac import SNAC
6
+ import gradio as gr
7
+
8
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
+
10
+ # load SNAC once
11
+ MODEL = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(DEVICE)
12
+
13
+ def reconstruct(audio_in):
14
+ """
15
+ audio_in is (sample_rate:int, data:np.ndarray) from gr.Audio(type="numpy")
16
+ returns (24000, np.ndarray)
17
+ """
18
+ if audio_in is None:
19
+ return None
20
+
21
+ sr, data = audio_in # data: (T,) or (T, C)
22
+
23
+ # convert to mono if stereo
24
+ if data.ndim == 2 and data.shape[1] > 1:
25
+ data = data.mean(axis=1)
26
+
27
+ x = torch.from_numpy(data).float().unsqueeze(0) # [1, T]
28
+
29
+ if sr != 24000:
30
+ x = torchaudio.functional.resample(x, orig_freq=sr, new_freq=24000)
31
+
32
+ x = x.unsqueeze(0).to(DEVICE) # [1, 1, T]
33
 
34
+ with torch.inference_mode():
35
+ y_hat, _, _, _, _ = MODEL(x) # [1, 1, T]
36
 
37
+ y = y_hat.squeeze(0).squeeze(0).detach().cpu()
38
+ y = torch.clamp(y, -1.0, 1.0) # safety clamp
 
 
39
 
40
+ return (24000, y.numpy())
 
 
41
 
 
 
42
 
43
+ with gr.Blocks(title="SNAC Audio Reconstructor") as demo:
44
+ gr.Markdown("## 🎵 SNAC Audio Reconstructor (24kHz)")
45
+ gr.Markdown("Upload or record audio. It’ll get resampled to 24kHz, "
46
+ "mono-ized, then passed through SNAC for reconstruction.")
47
 
48
+ with gr.Row():
49
+ with gr.Column():
50
+ audio_in = gr.Audio(
51
+ sources=["upload", "microphone"],
52
+ type="numpy",
53
+ label="Input audio"
54
+ )
55
+ btn = gr.Button("Reconstruct")
56
 
57
+ with gr.Column():
58
+ audio_out = gr.Audio(
59
+ type="numpy",
60
+ label="Reconstructed audio (24kHz)"
61
+ )
62
 
63
+ btn.click(fn=reconstruct, inputs=audio_in, outputs=audio_out)
 
64
 
65
+ if __name__ == "__main__":
66
+ demo.launch()