import argparse
import os
import time
import pathlib
from typing import Dict

import librosa
import numpy as np
import soundfile
import torch
import torch.nn as nn

from bytesep.models.lightning_modules import get_model_class
from bytesep.utils import read_yaml

class Separator:
    def __init__(
        self, model: nn.Module, segment_samples: int, batch_size: int, device: str
    ):
        r"""Separator for separating an audio clip into a target source.

        Args:
            model: nn.Module, trained model
            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
            batch_size: int, e.g., 12
            device: str, e.g., 'cuda'
        """
        self.model = model
        self.segment_samples = segment_samples
        self.batch_size = batch_size
        self.device = device

    def separate(self, input_dict: Dict) -> np.array:
        r"""Separate an audio clip into a target source.

        Args:
            input_dict: dict, e.g., {
                'waveform': (channels_num, audio_samples),
                ...,
            }

        Returns:
            sep_audio: (channels_num, audio_samples) | (target_sources_num, channels_num, audio_samples)
        """
        audio = input_dict['waveform']
        audio_samples = audio.shape[-1]

        # Pad the audio with zeros at the end so that its length can be
        # evenly divided by segment_samples.
        audio = self.pad_audio(audio)

        # Enframe long audio into segments.
        segments = self.enframe(audio, self.segment_samples)
        # (segments_num, channels_num, segment_samples)

        segments_input_dict = {'waveform': segments}

        if 'condition' in input_dict.keys():
            segments_num = len(segments)
            segments_input_dict['condition'] = np.tile(
                input_dict['condition'][None, :], (segments_num, 1)
            )
            # (segments_num, conditions_num)

        # Separate in mini-batches.
        sep_segments = self._forward_in_mini_batches(
            self.model, segments_input_dict, self.batch_size
        )['waveform']
        # (segments_num, channels_num, segment_samples)

        # Deframe segments into long audio.
        sep_audio = self.deframe(sep_segments)
        # (channels_num, padded_audio_samples)

        sep_audio = sep_audio[:, 0:audio_samples]
        # (channels_num, audio_samples)

        return sep_audio

    def pad_audio(self, audio: np.array) -> np.array:
        r"""Pad the audio with zeros at the end so that its length can be
        evenly divided by segment_samples.

        Args:
            audio: (channels_num, audio_samples)

        Returns:
            padded_audio: (channels_num, padded_audio_samples)
        """
        channels_num, audio_samples = audio.shape

        # Number of segments.
        segments_num = int(np.ceil(audio_samples / self.segment_samples))
        pad_samples = segments_num * self.segment_samples - audio_samples

        padded_audio = np.concatenate(
            (audio, np.zeros((channels_num, pad_samples))), axis=1
        )
        # (channels_num, padded_audio_samples)

        return padded_audio

    def enframe(self, audio: np.array, segment_samples: int) -> np.array:
        r"""Enframe long audio into half-overlapping segments.

        Args:
            audio: (channels_num, audio_samples)
            segment_samples: int

        Returns:
            segments: (segments_num, channels_num, segment_samples)
        """
        audio_samples = audio.shape[1]
        assert audio_samples % segment_samples == 0

        hop_samples = segment_samples // 2
        segments = []
        pointer = 0

        while pointer + segment_samples <= audio_samples:
            segments.append(audio[:, pointer : pointer + segment_samples])
            pointer += hop_samples

        segments = np.array(segments)

        return segments

    def deframe(self, segments: np.array) -> np.array:
        r"""Deframe segments into long audio.

        Args:
            segments: (segments_num, channels_num, segment_samples)

        Returns:
            output: (channels_num, audio_samples)
        """
        (segments_num, _, segment_samples) = segments.shape

        if segments_num == 1:
            return segments[0]

        assert self._is_integer(segment_samples * 0.25)
        assert self._is_integer(segment_samples * 0.75)

        output = []

        output.append(segments[0, :, 0 : int(segment_samples * 0.75)])

        for i in range(1, segments_num - 1):
            output.append(
                segments[
                    i, :, int(segment_samples * 0.25) : int(segment_samples * 0.75)
                ]
            )

        output.append(segments[-1, :, int(segment_samples * 0.25) :])

        output = np.concatenate(output, axis=-1)

        return output
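
    # Note on the enframe()/deframe() pair: enframe() advances by a hop of
    # segment_samples // 2, so consecutive segments overlap by 50%. deframe()
    # then keeps [0, 0.75*S) of the first segment, the central [0.25*S, 0.75*S)
    # of every middle segment, and [0.25*S, S) of the last segment, so the
    # padded timeline is covered exactly once. For example, with S = 4 and a
    # padded length of 8 samples, segments start at samples 0, 2 and 4, and
    # the kept regions map to absolute samples [0, 3), [3, 5) and [5, 8).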

    def _is_integer(self, x: float) -> bool:
        if x - int(x) < 1e-10:
            return True
        else:
            return False

    def _forward_in_mini_batches(
        self, model: nn.Module, segments_input_dict: Dict, batch_size: int
    ) -> Dict:
        r"""Forward data to model in mini-batch.

        Args:
            model: nn.Module
            segments_input_dict: dict, e.g., {
                'waveform': (segments_num, channels_num, segment_samples),
                ...,
            }
            batch_size: int

        Returns:
            output_dict: dict, e.g. {
                'waveform': (segments_num, channels_num, segment_samples),
            }
        """
        output_dict = {}

        pointer = 0
        segments_num = len(segments_input_dict['waveform'])

        while True:
            if pointer >= segments_num:
                break

            batch_input_dict = {}

            for key in segments_input_dict.keys():
                batch_input_dict[key] = torch.Tensor(
                    segments_input_dict[key][pointer : pointer + batch_size]
                ).to(self.device)

            pointer += batch_size

            with torch.no_grad():
                model.eval()
                batch_output_dict = model(batch_input_dict)

            for key in batch_output_dict.keys():
                self._append_to_dict(
                    output_dict, key, batch_output_dict[key].data.cpu().numpy()
                )

        for key in output_dict.keys():
            output_dict[key] = np.concatenate(output_dict[key], axis=0)

        return output_dict

    def _append_to_dict(self, dict, key, value):
        if key in dict.keys():
            dict[key].append(value)
        else:
            dict[key] = [value]


class SeparatorWrapper:
    def __init__(
        self, source_type='vocals', model=None, checkpoint_path=None, device='cuda'
    ):
        input_channels = 2
        target_sources_num = 1
        model_type = "ResUNet143_Subbandtime"
        segment_samples = 44100 * 10
        batch_size = 1

        self.checkpoint_path = self.download_checkpoints(checkpoint_path, source_type)

        if device == 'cuda' and torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'

        # Get model class.
        Model = get_model_class(model_type)

        # Create model.
        self.model = Model(
            input_channels=input_channels, target_sources_num=target_sources_num
        )

        # Load checkpoint.
        checkpoint = torch.load(self.checkpoint_path, map_location='cpu')
        self.model.load_state_dict(checkpoint["model"])

        # Move model to device.
        self.model.to(self.device)

        # Create separator.
        self.separator = Separator(
            model=self.model,
            segment_samples=segment_samples,
            batch_size=batch_size,
            device=self.device,
        )

    def download_checkpoints(self, checkpoint_path, source_type):
        if source_type == "vocals":
            checkpoint_bare_name = "resunet143_subbtandtime_vocals_8.8dB_350k_steps"
        elif source_type == "accompaniment":
            checkpoint_bare_name = (
                "resunet143_subbtandtime_accompaniment_16.4dB_350k_steps.pth"
            )
        else:
            raise NotImplementedError

        if not checkpoint_path:
            checkpoint_path = '{}/bytesep_data/{}.pth'.format(
                str(pathlib.Path.home()), checkpoint_bare_name
            )

        print('Checkpoint path: {}'.format(checkpoint_path))

        if (
            not os.path.exists(checkpoint_path)
            or os.path.getsize(checkpoint_path) < 4e8
        ):
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

            zenodo_dir = "https://zenodo.org/record/5507029/files"
            zenodo_path = os.path.join(
                zenodo_dir, "{}?download=1".format(checkpoint_bare_name)
            )

            os.system('wget -O "{}" "{}"'.format(checkpoint_path, zenodo_path))

        return checkpoint_path

    def separate(self, audio):
        input_dict = {'waveform': audio}
        sep_wav = self.separator.separate(input_dict)

        return sep_wav
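
# Example usage of SeparatorWrapper (a minimal sketch; 'mixture.wav' and
# 'sep_vocals.wav' are hypothetical file names). The wrapper expects a stereo
# waveform of shape (channels_num, audio_samples) sampled at 44100 Hz:
#
#     audio, _ = librosa.load('mixture.wav', sr=44100, mono=False)
#     separator = SeparatorWrapper(source_type='vocals', device='cuda')
#     sep_wav = separator.separate(audio)  # (channels_num, audio_samples)
#     soundfile.write(file='sep_vocals.wav', data=sep_wav.T, samplerate=44100)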


def inference(args):

    # Need to use torch.distributed if models contain inplace_abn.abn.InPlaceABNSync.
    import torch.distributed as dist

    dist.init_process_group(
        'gloo', init_method='file:///tmp/somefile', rank=0, world_size=1
    )

    # Arguments & parameters
    config_yaml = args.config_yaml
    checkpoint_path = args.checkpoint_path
    audio_path = args.audio_path
    output_path = args.output_path
    device = (
        torch.device('cuda')
        if args.cuda and torch.cuda.is_available()
        else torch.device('cpu')
    )

    configs = read_yaml(config_yaml)
    sample_rate = configs['train']['sample_rate']
    input_channels = configs['train']['channels']
    target_source_types = configs['train']['target_source_types']
    target_sources_num = len(target_source_types)
    model_type = configs['train']['model_type']

    segment_samples = int(30 * sample_rate)
    batch_size = 1

    print("Using {} for separating ..".format(device))

    # Paths
    if os.path.dirname(output_path) != "":
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Get model class.
    Model = get_model_class(model_type)

    # Create model.
    model = Model(input_channels=input_channels, target_sources_num=target_sources_num)

    # Load checkpoint.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint["model"])

    # Move model to device.
    model.to(device)

    # Create separator.
    separator = Separator(
        model=model,
        segment_samples=segment_samples,
        batch_size=batch_size,
        device=device,
    )

    # Load audio.
    audio, _ = librosa.load(audio_path, sr=sample_rate, mono=False)
    # audio = audio[None, :]

    input_dict = {'waveform': audio}

    # Separate.
    separate_time = time.time()

    sep_wav = separator.separate(input_dict)
    # (channels_num, audio_samples)

    print('Separate time: {:.3f} s'.format(time.time() - separate_time))

    # Write out separated audio.
    soundfile.write(file='_zz.wav', data=sep_wav.T, samplerate=sample_rate)
    os.system("ffmpeg -y -loglevel panic -i _zz.wav {}".format(output_path))
    print('Write out to {}'.format(output_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--config_yaml", type=str, required=True)
    parser.add_argument("--checkpoint_path", type=str, required=True)
    parser.add_argument("--audio_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--cuda", action='store_true', default=True)

    args = parser.parse_args()
    inference(args)
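
# Example command-line invocation (hypothetical script name and paths; the
# config YAML and checkpoint must be a matching pair for the chosen model_type):
#
#     python3 inference.py \
#         --config_yaml="./configs/vocals.yaml" \
#         --checkpoint_path="./checkpoints/vocals.pth" \
#         --audio_path="./mixture.wav" \
#         --output_path="./sep_vocals.wav" \
#         --cuda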