Spaces:
Runtime error
Runtime error
| import torch | |
| import torchaudio | |
| from torchaudio.functional import resample | |
| from transformers import AutoProcessor, EncodecModel | |
# Bandwidths selectable through ``bandwidth_id``. The chosen entry is forwarded
# as the ``bandwidth`` argument of ``EncodecModel.encode``.
ALL_BANDWIDTHS = [1.1]


class VoilaTokenizer:
    """Thin wrapper around a pretrained ``EncodecModel`` used as an audio tokenizer.

    ``encode`` turns a mono waveform into discrete audio codes and ``decode``
    turns such codes back into a waveform. Both run without autograd, so the
    returned tensors never require grad.
    """

    def __init__(
        self,
        model_path="maitrix-org/Voila-Tokenizer",
        bandwidth_id=0,
        device="cpu",
    ):
        """Load the processor and model from ``model_path`` onto ``device``.

        Args:
            model_path: Identifier or path given to ``from_pretrained``.
            bandwidth_id: Index into ``ALL_BANDWIDTHS``.
            device: Anything accepted by ``torch.device``.

        Raises:
            IndexError: If ``bandwidth_id`` is not a valid index.
        """
        self.device = torch.device(device)
        self.bandwidth = ALL_BANDWIDTHS[bandwidth_id]
        self.bandwidth_id = torch.tensor([bandwidth_id], device=device)
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = EncodecModel.from_pretrained(model_path).to(device)
        self.sampling_rate = self.processor.sampling_rate
        self.model_version = self.model.config.model_version

    @torch.no_grad()
    def encode(self, wav, sr):
        """Encode a single mono waveform into audio codes.

        Args:
            wav: Waveform as a tensor or array-like, shaped ``(T,)``,
                ``(1, T)`` or ``(1, 1, T)``.
            sr: Sampling rate of ``wav``; the audio is resampled when it
                differs from the processor's sampling rate.

        Returns:
            ``audio_codes[0, 0]`` of the model's encoder output
            (presumably ``(num_codebooks, num_frames)`` -- confirm against the
            transformers Encodec documentation).

        Raises:
            AssertionError: If a 2-D/3-D input has non-singleton leading dims.
            ValueError: If ``wav`` is not 1-, 2- or 3-dimensional.
        """
        # as_tensor avoids the copy-construct warning that torch.tensor()
        # emits when ``wav`` is already a tensor.
        wav = torch.as_tensor(wav, dtype=torch.float32, device=self.device)

        target_sr = self.processor.sampling_rate
        if sr != target_sr:
            wav = resample(wav, sr, target_sr)

        # The model is called with a 3-D tensor whose leading dims are both 1.
        if wav.ndim == 1:
            wav = wav[None, None, :]
        elif wav.ndim == 2:
            assert wav.shape[0] == 1, "2-D input must have shape (1, T)"
            wav = wav[None, :]
        elif wav.ndim == 3:
            assert wav.shape[0] == 1 and wav.shape[1] == 1, (
                "3-D input must have shape (1, 1, T)"
            )
        else:
            raise ValueError(
                f"wav must be 1-, 2- or 3-dimensional, got {wav.ndim} dimensions"
            )

        encoder_outputs = self.model.encode(wav, bandwidth=self.bandwidth)
        return encoder_outputs.audio_codes[0, 0]

    @torch.no_grad()
    def decode(self, audio_codes):
        """Decode 2-D audio codes (as returned by ``encode``) into a waveform.

        Args:
            audio_codes: 2-D tensor of codes.

        Returns:
            ``audio_values[0, 0]`` of the model's decoder output (presumably a
            1-D waveform -- confirm against the transformers documentation).
            The result does not require grad, so ``.cpu().numpy()`` works.

        Raises:
            AssertionError: If ``audio_codes`` is not 2-dimensional.
        """
        assert len(audio_codes.shape) == 2, "audio_codes must be 2-dimensional"
        # Codes produced elsewhere (e.g. loaded from disk) may live on another
        # device than the model.
        audio_codes = torch.as_tensor(audio_codes, device=self.device)
        # ``[None]`` is the per-chunk audio scale list: no scaling applied.
        audio_values = self.model.decode(audio_codes[None, None, :, :], [None])[0]
        return audio_values[0, 0]
if __name__ == "__main__":
    # Round-trip smoke test: encode a wav file to audio codes, decode it back,
    # print the shapes and tokens-per-second, and write the reconstruction.
    import argparse

    import soundfile as sf

    parser = argparse.ArgumentParser(
        description="Encode a wav file to audio codes and decode it back."
    )
    # Required: without it torchaudio.load() would be called with None.
    parser.add_argument("--wav", type=str, required=True, help="input audio file")
    args = parser.parse_args()

    wav, sr = torchaudio.load(args.wav)
    # Keep only the first row of a multi-dimensional load result
    # (presumably the first channel).
    if wav.ndim > 1:
        wav = wav[0]

    # Fall back to CPU when no GPU is available instead of crashing.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = VoilaTokenizer(device=device)

    audio_codes = model.encode(wav, sr)
    # detach() so the conversion works even if the output tracks gradients.
    audio_values = model.decode(audio_codes).detach().cpu().numpy()

    # Code frames produced per second of decoded audio.
    duration_s = audio_values.shape[-1] / model.processor.sampling_rate
    tps = audio_codes.shape[-1] / duration_s
    print(audio_codes.shape, audio_values.shape, tps)

    sf.write("audio_mt.wav", audio_values, model.processor.sampling_rate)