mrfakename commited on
Commit
ca49a8a
·
verified ·
1 Parent(s): a926d98

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -3,9 +3,9 @@ import torchaudio
3
  from snac import SNAC
4
  import gradio as gr
5
 
6
- # pick the right SNAC model for your audio sample rate
7
  MODEL_NAME = "hubertsiuzdak/snac_24khz"
8
- SR = 24000
9
 
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
  model = SNAC.from_pretrained(MODEL_NAME).eval().to(DEVICE)
@@ -20,18 +20,24 @@ def reconstruct(audio_in):
20
  if data.ndim == 2:
21
  data = data.mean(axis=1)
22
 
23
- # turn into torch [1,1,T]
24
- audio = torch.from_numpy(data).float().unsqueeze(0).unsqueeze(0).to(DEVICE)
 
 
 
 
 
 
 
25
 
26
- # run through SNAC
27
  with torch.inference_mode():
28
  audio_hat, codes = model(audio)
29
 
30
  y = audio_hat.squeeze().cpu().numpy()
31
- return (SR, y)
32
 
33
  with gr.Blocks(title="SNAC Round-Trip Demo") as demo:
34
- gr.Markdown("## 🎧 SNAC Audio Reconstructor (minimal!)")
35
 
36
  with gr.Row():
37
  with gr.Column():
 
3
  from snac import SNAC
4
  import gradio as gr
5
 
6
+ # choose your SNAC model + target sample rate
7
  MODEL_NAME = "hubertsiuzdak/snac_24khz"
8
+ TARGET_SR = 24000
9
 
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
  model = SNAC.from_pretrained(MODEL_NAME).eval().to(DEVICE)
 
20
  if data.ndim == 2:
21
  data = data.mean(axis=1)
22
 
23
+ # torchify
24
+ audio = torch.from_numpy(data).float().unsqueeze(0) # [1,T]
25
+
26
+ # resample to target SR
27
+ if sr != TARGET_SR:
28
+ audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=TARGET_SR)
29
+
30
+ # expand to [B,1,T]
31
+ audio = audio.unsqueeze(0).to(DEVICE)
32
 
 
33
  with torch.inference_mode():
34
  audio_hat, codes = model(audio)
35
 
36
  y = audio_hat.squeeze().cpu().numpy()
37
+ return (TARGET_SR, y)
38
 
39
  with gr.Blocks(title="SNAC Round-Trip Demo") as demo:
40
+ gr.Markdown("## 🎧 SNAC Audio Reconstructor (with resampling)")
41
 
42
  with gr.Row():
43
  with gr.Column():