Spaces:
Sleeping
Sleeping
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| import torchaudio | |
| import json | |
| import os | |
| import numpy as np | |
| import librosa | |
| from torch.nn.utils.rnn import pad_sequence | |
| from modules import whisper_extractor as whisper | |
class TorchaudioDataset(torch.utils.data.Dataset):
    """Dataset that loads utterances with torchaudio as mono 1-D waveforms.

    Each item is a tuple ``(utt_info, wav, length)``: the metadata entry of
    the utterance, its waveform as a tensor of shape ``(T,)`` at ``self.sr``,
    and ``T`` (the number of samples before any batch padding).
    """

    def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None):
        """
        Args:
            cfg: config. ``cfg.preprocess.processed_dir``,
                ``cfg.preprocess.train_file`` and ``cfg.preprocess.valid_file``
                are read only when ``metadata`` is None.
            dataset: dataset name (str), used as the sub-directory of
                ``processed_dir`` that holds the metadata files.
            sr: target sampling rate of the returned waveforms.
            accelerator: optional object whose ``device`` attribute is stored
                as ``self.device``.
            metadata: optional pre-loaded list of utterance entries. Every
                entry must provide a ``"Path"`` key. When None, the train and
                valid metadata files are loaded and concatenated.

        Raises:
            TypeError: if ``dataset`` is not a str.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not isinstance(dataset, str):
            raise TypeError(
                f"dataset must be a str, got {type(dataset).__name__}"
            )

        self.sr = sr
        self.cfg = cfg

        if metadata is None:
            self.train_metadata_path = os.path.join(
                cfg.preprocess.processed_dir, dataset, cfg.preprocess.train_file
            )
            self.valid_metadata_path = os.path.join(
                cfg.preprocess.processed_dir, dataset, cfg.preprocess.valid_file
            )
            self.metadata = self.get_metadata()
        else:
            self.metadata = metadata

        # Device preference: accelerator's device, then CUDA, then CPU.
        if accelerator is not None:
            self.device = accelerator.device
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

    def get_metadata(self):
        """Load the train and valid JSON metadata files into one list.

        Returns:
            list: entries of the train file followed by those of the valid file.
        """
        metadata = []
        for path in (self.train_metadata_path, self.valid_metadata_path):
            with open(path, "r", encoding="utf-8") as f:
                metadata.extend(json.load(f))
        return metadata

    def __len__(self):
        """Return the number of utterances."""
        return len(self.metadata)

    def __getitem__(self, index):
        """Load one utterance.

        Args:
            index: position of the utterance in ``self.metadata``.

        Returns:
            tuple: ``(utt_info, wav, length)`` with ``wav`` of shape ``(T,)``
            and ``length == T``.
        """
        utt_info = self.metadata[index]
        wav, sr = torchaudio.load(utt_info["Path"])

        # downmixing first: a multi-channel file is then resampled only once
        # (resampling is linear, so the result matches resample-then-average
        # up to floating-point rounding)
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
        assert wav.shape[0] == 1

        # resample
        if sr != self.sr:
            wav = torchaudio.functional.resample(wav, sr, self.sr)

        # wav: (T)
        wav = wav.squeeze(0)

        # record the length of wav without padding
        length = wav.shape[0]

        return utt_info, wav, length
class LibrosaDataset(TorchaudioDataset):
    """Variant of ``TorchaudioDataset`` that reads audio with librosa.

    Items keep the same layout as the parent class:
    ``(utt_info, wav, length)``.
    """

    def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None):
        """Forward every argument unchanged to ``TorchaudioDataset.__init__``."""
        super().__init__(
            cfg, dataset, sr, accelerator=accelerator, metadata=metadata
        )

    def __getitem__(self, index):
        """Load one utterance with ``librosa.load`` at ``self.sr``.

        Args:
            index: position of the utterance in ``self.metadata``.

        Returns:
            tuple: ``(utt_info, wav, length)`` where ``wav`` is a torch tensor
            built from the librosa output and ``length`` is its first
            dimension.
        """
        entry = self.metadata[index]

        # librosa resamples to the requested rate while loading
        samples, _ = librosa.load(entry["Path"], sr=self.sr)

        # wav: (T)
        waveform = torch.from_numpy(samples)

        # length of the waveform before any batch padding
        return entry, waveform, waveform.shape[0]
class FFmpegDataset(TorchaudioDataset):
    """Variant of ``TorchaudioDataset`` that reads audio with the whisper loader.

    Items keep the same layout as the parent class:
    ``(utt_info, wav, length)``.
    """

    def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None):
        """Forward every argument unchanged to ``TorchaudioDataset.__init__``."""
        super().__init__(
            cfg, dataset, sr, accelerator=accelerator, metadata=metadata
        )

    def __getitem__(self, index):
        """Load one utterance with ``whisper.load_audio``.

        Args:
            index: position of the utterance in ``self.metadata``.

        Returns:
            tuple: ``(utt_info, wav, length)`` where ``wav`` is a torch tensor
            built from the loader output and ``length`` is its first
            dimension.
        """
        entry = self.metadata[index]

        # wav: (T,); sr = 16000
        # NOTE(review): self.sr is not passed to the loader -- confirm that
        # callers construct this dataset with sr=16000.
        samples = whisper.load_audio(entry["Path"])

        # convert to torch tensor
        waveform = torch.from_numpy(samples)

        # length of the waveform before any batch padding
        return entry, waveform, waveform.shape[0]
def collate_batch(batch_list):
    """
    Merge dataset items into a single padded batch.

    Args:
        batch_list: list of (metadata, wav, length)

    Returns:
        tuple: ``(metadata, wavs, lens)`` -- the list of metadata entries, the
        waveforms padded by ``pad_sequence`` into one ``(B, T)`` tensor, and
        the list of unpadded lengths.
    """
    infos, waveforms, lengths = [], [], []
    for sample in batch_list:
        infos.append(sample[0])
        waveforms.append(sample[1])
        lengths.append(sample[2])

    # wavs: (B, T)
    padded = pad_sequence(waveforms, batch_first=True)

    return infos, padded, lengths