mrfakename committed on
Commit
a926d98
·
verified ·
1 Parent(s): 574dde7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -26
app.py CHANGED
@@ -3,50 +3,51 @@ import torchaudio
3
from snac import SNAC
import gradio as gr

# Prefer a GPU when one is present; everything below runs on this device.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Load the 24 kHz SNAC codec once at module import and switch it to
# inference mode (eval() returns the model itself, so MODEL is the codec).
MODEL = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(DEVICE)
MODEL.eval()
8
 
9
def reconstruct(audio_in):
    """Encode and decode the input audio through the SNAC codec.

    Args:
        audio_in: Gradio numpy audio, i.e. a ``(sample_rate, data)`` tuple
            where ``data`` is a numpy array — presumably ``(T,)`` for mono
            or ``(T, C)`` for multi-channel (Gradio's ``type="numpy"``
            convention); or ``None`` when no audio was provided.

    Returns:
        ``(24000, waveform)`` with the reconstructed mono waveform as a
        numpy array, or ``None`` when ``audio_in`` is ``None``.
    """
    if audio_in is None:
        return None

    sr, data = audio_in  # (sr, np.ndarray)

    # to tensor [channels, T].
    # BUG FIX: the original applied data.T unconditionally. For a mono (T,)
    # array, .T is a no-op, so the "stereo -> mono" mean below averaged over
    # *time* and collapsed the whole clip to a single sample. Handle the two
    # ranks explicitly instead.
    if data.ndim == 1:
        data = data[None, :]  # (T,)   -> (1, T)
    else:
        data = data.T         # (T, C) -> (C, T)
    audio = torch.from_numpy(data).float()

    # resample to the model's 24 kHz rate if needed
    if sr != 24000:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=24000)
        audio = resampler(audio)

    # stereo -> mono (now guaranteed to average over channels, not time)
    if audio.size(0) > 1:
        audio = audio.mean(dim=0, keepdim=True)

    # expand to [1, 1, T] as the codec expects a batch dimension
    audio = audio.unsqueeze(0).to(DEVICE)

    with torch.inference_mode():
        out = MODEL(audio)
        # some SNAC versions return (audio_hat, codes); keep only the audio
        audio_hat = out[0] if isinstance(out, (list, tuple)) else out

    y = audio_hat.squeeze().cpu().numpy()
    return (24000, y)
35
 
36
# Two-column Gradio front-end: input (upload/mic + button) on the left,
# reconstructed audio player on the right.
with gr.Blocks() as demo:
    gr.Markdown("## SNAC Audio Reconstructor (24 kHz)")

    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="numpy",
                label="Input Audio",
            )
            reconstruct_btn = gr.Button("Reconstruct")
        with gr.Column():
            output_audio = gr.Audio(type="numpy", label="Reconstructed")

    # Run the codec round-trip when the button is pressed.
    reconstruct_btn.click(fn=reconstruct, inputs=input_audio, outputs=output_audio)

if __name__ == "__main__":
    demo.launch()
 
3
from snac import SNAC
import gradio as gr

# Checkpoint and its sample rate — NOTE: these must stay in sync
# (this checkpoint is the 24 kHz SNAC model).
MODEL_NAME = "hubertsiuzdak/snac_24khz"
SR = 24000

# Run on GPU when available, CPU otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Instantiate the codec once at import time, in eval mode, on the chosen
# device (eval() returns the module itself, so the chain is equivalent).
model = SNAC.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()
12
 
13
def reconstruct(audio_in):
    """Round-trip the input audio through the SNAC codec.

    Args:
        audio_in: Gradio numpy audio, i.e. a ``(sample_rate, data)`` tuple
            where ``data`` is ``(T,)`` for mono or ``(T, C)`` for
            multi-channel; or ``None`` when no audio was provided.

    Returns:
        ``(SR, waveform)`` with the reconstructed mono waveform as a numpy
        array at the model's 24 kHz rate, or ``None`` for no input.
    """
    if audio_in is None:
        return None

    sr, data = audio_in  # data: (T,) or (T, C)

    # convert stereo -> mono before any tensor work
    if data.ndim == 2:
        data = data.mean(axis=1)

    audio = torch.from_numpy(data).float()

    # BUG FIX: this rewrite dropped the resampling step the previous version
    # had. The model runs at SR (24 kHz); without resampling, any input at a
    # different rate is both mis-encoded and returned mislabeled as SR,
    # producing a pitch/speed shift. Resample whenever sr != SR.
    if sr != SR:
        audio = torchaudio.transforms.Resample(orig_freq=sr, new_freq=SR)(audio)

    # turn into torch [1, 1, T] as SNAC expects batch and channel dims
    audio = audio.unsqueeze(0).unsqueeze(0).to(DEVICE)

    # run through SNAC; the model returns (reconstruction, codes) and the
    # codes are not needed here
    with torch.inference_mode():
        audio_hat, _codes = model(audio)

    y = audio_hat.squeeze().cpu().numpy()
    return (SR, y)
32
 
33
# Minimal Gradio front-end: upload or record audio, push it through the
# SNAC encode/decode round trip, and play back the reconstruction.
with gr.Blocks(title="SNAC Round-Trip Demo") as demo:
    gr.Markdown("## 🎧 SNAC Audio Reconstructor (minimal!)")

    with gr.Row():
        with gr.Column():
            source_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="numpy",
                label="Input audio",
            )
            run_button = gr.Button("Encode + Decode")
        with gr.Column():
            result_audio = gr.Audio(type="numpy", label="Reconstructed audio")

    # Wire the button to the round-trip function.
    run_button.click(reconstruct, inputs=source_audio, outputs=result_audio)

if __name__ == "__main__":
    demo.launch()