mrfakename committed on
Commit
574dde7
·
verified ·
1 Parent(s): 55781cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -33
app.py CHANGED
@@ -1,60 +1,50 @@
1
- # pip install gradio torch torchaudio soundfile snac
2
-
3
  import torch
4
  import torchaudio
5
  from snac import SNAC
6
  import gradio as gr
7
 
8
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
-
10
- # load SNAC once
11
  MODEL = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(DEVICE)
12
 
13
  def reconstruct(audio_in):
14
  if audio_in is None:
15
  return None
 
16
 
17
- sr, data = audio_in
18
-
19
- if data.ndim == 2 and data.shape[1] > 1:
20
- data = data.mean(axis=1)
21
-
22
- x = torch.from_numpy(data).float().unsqueeze(0) # [1, T]
23
 
 
24
  if sr != 24000:
25
- x = torchaudio.functional.resample(x, orig_freq=sr, new_freq=24000)
 
26
 
27
- x = x.unsqueeze(0).to(DEVICE) # [1, 1, T]
 
 
28
 
29
- with torch.inference_mode():
30
- out = MODEL(x)
31
- audio_hat = out[0] if isinstance(out, (list, tuple)) else out
32
-
33
- y = audio_hat.squeeze(0).squeeze(0).detach().cpu()
34
- y = torch.clamp(y, -1.0, 1.0)
35
 
36
- return (24000, y.numpy())
 
 
37
 
 
 
38
 
39
- with gr.Blocks(title="SNAC Audio Reconstructor") as demo:
40
- gr.Markdown("## 🎵 SNAC Audio Reconstructor (24kHz)")
41
- gr.Markdown("Upload or record audio. It’ll get resampled to 24kHz, "
42
- "mono-ized, then passed through SNAC for reconstruction.")
43
 
44
  with gr.Row():
45
  with gr.Column():
46
- audio_in = gr.Audio(
47
- sources=["upload", "microphone"],
48
- type="numpy",
49
- label="Input audio"
50
- )
51
  btn = gr.Button("Reconstruct")
52
-
53
  with gr.Column():
54
- audio_out = gr.Audio(
55
- type="numpy",
56
- label="Reconstructed audio (24kHz)"
57
- )
58
 
59
  btn.click(fn=reconstruct, inputs=audio_in, outputs=audio_out)
60
 
 
 
 
1
  import torch
2
  import torchaudio
3
  from snac import SNAC
4
  import gradio as gr
5
 
6
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
7
  MODEL = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(DEVICE)
8
 
9
  def reconstruct(audio_in):
10
  if audio_in is None:
11
  return None
12
+ sr, data = audio_in # (sr, np.ndarray)
13
 
14
+ # to tensor [channels, T]
15
+ audio = torch.from_numpy(data.T).float()
 
 
 
 
16
 
17
+ # resample to 24k if needed
18
  if sr != 24000:
19
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=24000)
20
+ audio = resampler(audio)
21
 
22
+ # stereo -> mono
23
+ if audio.size(0) > 1:
24
+ audio = audio.mean(dim=0, keepdim=True)
25
 
26
+ # expand to [1,1,T]
27
+ audio = audio.unsqueeze(0).to(DEVICE)
 
 
 
 
28
 
29
+ with torch.inference_mode():
30
+ out = MODEL(audio)
31
+ audio_hat = out[0] if isinstance(out,(list,tuple)) else out
32
 
33
+ y = audio_hat.squeeze().cpu().numpy()
34
+ return (24000, y)
35
 
36
+ with gr.Blocks() as demo:
37
+ gr.Markdown("## SNAC Audio Reconstructor (24 kHz)")
 
 
38
 
39
  with gr.Row():
40
  with gr.Column():
41
+ audio_in = gr.Audio(sources=["upload","microphone"],
42
+ type="numpy",
43
+ label="Input Audio")
 
 
44
  btn = gr.Button("Reconstruct")
 
45
  with gr.Column():
46
+ audio_out = gr.Audio(type="numpy",
47
+ label="Reconstructed")
 
 
48
 
49
  btn.click(fn=reconstruct, inputs=audio_in, outputs=audio_out)
50