mrfakename committed on
Commit
17fb016
·
verified ·
1 Parent(s): 5196160

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -0
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ from snac import SNAC
4
+ import soundfile as sf
5
+
6
# Round-trip a single audio file through the pretrained SNAC 24 kHz codec:
# load -> resample to 24 kHz -> down-mix to mono -> encode/decode -> save.
filename = "/content/en_sample.wav"
audio, sr = torchaudio.load(filename)

# The "snac_24khz" checkpoint is loaded below, so bring the input to 24 kHz.
if sr != 24000:
    resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=24000)
    audio = resampler(audio)

# Down-mix multi-channel audio to mono by averaging over the channel axis.
if audio.size(0) > 1:
    audio = torch.mean(audio, dim=0, keepdim=True)

# At this point audio is [1, T] (channel, time); the batch dimension that
# makes it [1, 1, T] is only added right before inference further down.
print("Audio size after processing:", audio.size(), audio.shape)

# Load the SNAC model in eval mode.
model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()

# Run on the GPU when one is available; model and input must share a device.
if torch.cuda.is_available():
    model = model.cuda()
    audio = audio.cuda()

# Add the leading batch dimension: [1, T] -> [1, 1, T].
audio = audio.unsqueeze(0)

# Encode and decode the audio with SNAC.
# NOTE(review): the number of values returned by the forward pass appears to
# differ between snac releases (a 5-tuple in early versions, a 2-tuple of
# (audio_hat, codes) later) -- confirm against the installed version. Only the
# reconstruction is needed here and it is the first element in both layouts,
# so index it instead of unpacking a fixed number of values.
with torch.inference_mode():
    outputs = model(audio)
audio_hat = outputs[0]

# Move the reconstruction back to the CPU and convert it to a NumPy array.
audio_hat = audio_hat.detach().cpu().numpy()

# Save the reconstructed audio; squeeze() drops the singleton batch/channel
# dimensions so a 1-D mono signal is written.
sf.write('reconstructed_audio.wav', audio_hat.squeeze(), 24000)