# MassivelyMultilingualTTS/Architectures/ControllabilityGAN/dataset/speaker_embeddings_dataset.py
import os

import numpy as np
import torch

class SpeakerEmbeddingsDataset(torch.utils.data.Dataset):
    """Dataset of precomputed speaker embeddings for the controllability GAN.

    Each item is a z-normalized embedding paired with a placeholder label;
    per-sample class labels are kept in self.y.
    """

    def __init__(self, feature_path, device, mode='utterance'):
        super().__init__()

        modes = ['utterance', 'speaker']
        assert mode in modes, f'mode: {mode} is not supported'
        if mode == 'utterance':
            self.mode = 'utt'
        elif mode == 'speaker':
            self.mode = 'spk'

        self.device = device
        self.x, self.speakers = self._load_features(feature_path)

        # Map every speaker to a contiguous class index; get_speaker(),
        # get_num_speaker() and set_labels() depend on self.y and self.class2spk.
        unique_speakers = sorted(set(self.speakers))
        spk2class = dict(zip(unique_speakers, range(len(unique_speakers))))
        self.y = torch.tensor([spk2class[spk] for spk in self.speakers]).to(self.device)
        self.class2spk = dict(zip(spk2class.values(), spk2class.keys()))

    def __len__(self):
        return len(self.speakers)

    def __getitem__(self, index):
        embedding = self.normalize_embedding(self.x[index])
        # The label slot is a placeholder tensor; class labels live in self.y.
        return embedding, torch.zeros([0])

    def normalize_embedding(self, vector):
        # z-normalize with the global mean and std computed in _load_features().
        return (vector - self.mean) / self.std

    def get_speaker(self, label):
        return self.class2spk[label]

    def get_embedding_dim(self):
        return self.x.shape[-1]

    def get_num_speaker(self):
        return len(torch.unique(self.y))

    def set_labels(self, labels):
        self.y_old = self.y
        if isinstance(labels, (int, float)):
            # A scalar assigns the same label to every sample.
            self.y = torch.full(size=(len(self),), fill_value=labels).to(self.device)
        else:
            # A sequence assigns one label per sample.
            assert len(labels) == len(self), 'labels must match the dataset length'
            self.y = torch.tensor(labels).to(self.device)
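
    # Usage sketch (hypothetical call sites): set_labels(3) gives every sample
    # the class label 3, while set_labels(list_of_labels) assigns one label per
    # sample; list_of_labels must have len(dataset) entries.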

    def _load_features(self, feature_path):
        if os.path.isfile(feature_path):
            # Single-file case: one tensor (or list of tensors) of embeddings,
            # with no speaker information attached.
            vectors = torch.load(feature_path, map_location=self.device)
            if isinstance(vectors, list):
                vectors = torch.stack(vectors)
            self.mean = torch.mean(vectors)
            self.std = torch.std(vectors)
            # All samples get the same dummy speaker label.
            return vectors, [0] * vectors.size(0)
        else:
            # Directory case: feature_path is expected to be a pathlib.Path,
            # with ID-to-row index files stored alongside the embeddings.
            vectors = torch.load(feature_path, map_location=self.device)
            self.mean = torch.mean(vectors)
            self.std = torch.std(vectors)
            # Each line of the {utt,spk}2idx file maps an ID to a row index.
            spk2idx = {}
            with open(feature_path / f'{self.mode}2idx', 'r') as f:
                for line in f:
                    split_line = line.strip().split()
                    if len(split_line) == 2:
                        spk2idx[split_line[0].strip()] = int(split_line[1])

            speakers, indices = zip(*spk2idx.items())

            if (feature_path / 'utt2spk').exists():  # spk2idx contains utt_ids, not speaker_ids
                utt2spk = {}
                with open(feature_path / 'utt2spk', 'r') as f:
                    for line in f:
                        split_line = line.strip().split()
                        if len(split_line) == 2:
                            utt2spk[split_line[0].strip()] = split_line[1].strip()
                speakers = [utt2spk[utt] for utt in speakers]

            return vectors[np.array(indices)], speakers
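
    # Assumed on-disk layout for the directory case (inferred from the readers
    # above):
    #   <feature_path>/utt2idx  or  <feature_path>/spk2idx
    #       one "<id> <row_index>" pair per line
    #   <feature_path>/utt2spk  (optional)
    #       one "<utterance_id> <speaker_id>" pair per line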

    def _reformat_features(self, features):
        if len(features.shape) == 2:
            # (num_samples, dim) -> (num_samples, 1, 1, dim)
            return features.reshape(features.shape[0], 1, 1, features.shape[1])
        # Pass anything else through unchanged instead of implicitly returning None.
        return features
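

# Usage sketch (illustration only): 'example_embeddings.pt' is a hypothetical
# file containing a tensor or list of speaker embedding vectors.
if __name__ == '__main__':
    device = torch.device('cpu')
    dataset = SpeakerEmbeddingsDataset('example_embeddings.pt', device=device)
    loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
    for embeddings, placeholder_labels in loader:
        # embeddings: (batch_size, embedding_dim), z-normalized
        print(embeddings.shape, placeholder_labels.shape)
        break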