mrfakename committed on
Commit
9e72440
·
verified ·
1 Parent(s): ca49a8a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -33
app.py CHANGED
@@ -1,59 +1,49 @@
 
 
1
  import torch
2
  import torchaudio
3
- from snac import SNAC
4
  import gradio as gr
5
-
6
- # choose your SNAC model + target sample rate
7
- MODEL_NAME = "hubertsiuzdak/snac_24khz"
8
- TARGET_SR = 24000
9
 
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
- model = SNAC.from_pretrained(MODEL_NAME).eval().to(DEVICE)
 
12
 
13
- def reconstruct(audio_in):
14
  if audio_in is None:
15
  return None
16
 
17
- sr, data = audio_in # data: (T,) or (T,C)
18
 
19
- # convert stereo → mono
20
  if data.ndim == 2:
21
  data = data.mean(axis=1)
22
 
23
- # torchify
24
- audio = torch.from_numpy(data).float().unsqueeze(0) # [1,T]
25
 
26
- # resample to target SR
27
  if sr != TARGET_SR:
28
- audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=TARGET_SR)
29
-
30
- # expand to [B,1,T]
31
- audio = audio.unsqueeze(0).to(DEVICE)
32
 
 
 
33
  with torch.inference_mode():
34
- audio_hat, codes = model(audio)
 
35
 
36
- y = audio_hat.squeeze().cpu().numpy()
37
  return (TARGET_SR, y)
38
 
39
- with gr.Blocks(title="SNAC Round-Trip Demo") as demo:
40
- gr.Markdown("## 🎧 SNAC Audio Reconstructor (with resampling)")
41
-
42
  with gr.Row():
43
  with gr.Column():
44
- audio_in = gr.Audio(
45
- sources=["upload", "microphone"],
46
- type="numpy",
47
- label="Input audio"
48
- )
49
- btn = gr.Button("Encode + Decode")
50
  with gr.Column():
51
- audio_out = gr.Audio(
52
- type="numpy",
53
- label="Reconstructed audio"
54
- )
55
-
56
- btn.click(reconstruct, inputs=audio_in, outputs=audio_out)
57
 
58
  if __name__ == "__main__":
59
  demo.launch()
 
1
+ # pip install gradio torch torchaudio snac
2
+
3
  import torch
4
  import torchaudio
 
5
  import gradio as gr
6
+ from snac import SNAC
 
 
 
7
 
8
# Runtime configuration: prefer GPU when present, fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Sample rate expected by the 32 kHz SNAC checkpoint loaded below.
TARGET_SR = 32000
# Load the codec once at startup, in eval mode, on the chosen device.
MODEL = SNAC.from_pretrained("hubertsiuzdak/snac_32khz").eval().to(DEVICE)
11
 
12
def encode_then_decode(audio_in):
    """Round-trip audio through the SNAC codec: resample -> encode() -> decode().

    Parameters
    ----------
    audio_in : tuple[int, numpy.ndarray] | None
        Gradio ``type="numpy"`` audio: ``(sample_rate, data)`` where ``data``
        is ``(T,)`` mono or ``(T, C)`` multi-channel. Gradio delivers int16
        PCM for uploaded/recorded audio. ``None`` when no input was given.

    Returns
    -------
    tuple[int, numpy.ndarray] | None
        ``(TARGET_SR, reconstructed_waveform)`` as float32 in roughly
        [-1, 1], or ``None`` when there was no input.
    """
    if audio_in is None:
        return None

    sr, data = audio_in  # data: (T,) mono or (T, C) stereo

    # Record whether the source was integer PCM BEFORE any processing:
    # the stereo mean below promotes int16 to float64 while keeping the
    # int16-scale values, which would hide the original dtype.
    int_pcm = data.dtype.kind == "i"

    # mono-ize if needed
    if data.ndim == 2:
        data = data.mean(axis=1)

    # torchify to [1, T]
    x = torch.from_numpy(data).float().unsqueeze(0)

    # BUGFIX: Gradio hands us int16 PCM; without this the model sees samples
    # in the +/-32768 range instead of [-1, 1] and the output is garbage.
    if int_pcm:
        x = x / 32768.0

    # resample to the model's target SR
    if sr != TARGET_SR:
        x = torchaudio.functional.resample(x, orig_freq=sr, new_freq=TARGET_SR)

    # expand to [B, 1, T] then encode->decode
    x = x.unsqueeze(0).to(DEVICE)  # [1, 1, T]
    with torch.inference_mode():
        codes = MODEL.encode(x)
        y = MODEL.decode(codes)  # [1, 1, T]

    y = y.squeeze().detach().cpu().numpy()
    return (TARGET_SR, y)
37
 
38
# --- UI: one input column, one output column, single round-trip action ---
with gr.Blocks(title="SNAC Encode→Decode (Simple)") as demo:
    gr.Markdown("## 🎧 SNAC Encode Decode (32 kHz)\nResample → `encode()` → `decode()` — that’s it.")

    with gr.Row():
        with gr.Column():
            # Accept either a file upload or a live microphone recording.
            source = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Input audio")
            trigger = gr.Button("Encode + Decode")
        with gr.Column():
            result = gr.Audio(type="numpy", label="Reconstructed (32 kHz)")
            # Wire the button to the codec round-trip.
            trigger.click(encode_then_decode, inputs=source, outputs=result)

if __name__ == "__main__":
    demo.launch()