junseok committed
Commit 2216a22 · Parent: 09bc42b

first commit

Files changed (5)
  1. app.py +47 -0
  2. predict.py +119 -0
  3. requirements.txt +7 -0
  4. score.py +102 -0
  5. ssl_ecapa_model.py +314 -0
app.py ADDED
@@ -0,0 +1,47 @@
+ from score import load_model
+ from predict import loadWav
+ import torch
+ import torch.nn.functional as F
+ import gradio as gr
+
+ model = load_model("wavlm_ecapa.model")
+ model.eval()
+
+ def calc_voxsim(inp_path, ref_path):
+     inp_wavs, inp_wav = loadWav(inp_path)
+     ref_wavs, ref_wav = loadWav(ref_path)
+
+     inp_wavs = torch.FloatTensor(inp_wavs)
+     inp_wav = torch.FloatTensor(inp_wav)
+     ref_wavs = torch.FloatTensor(ref_wavs)
+     ref_wav = torch.FloatTensor(ref_wav)
+
+     with torch.no_grad():
+         input_emb_1 = F.normalize(model.forward(inp_wavs), p=2, dim=1)
+         input_emb_2 = F.normalize(model.forward(inp_wav), p=2, dim=1)
+         ref_emb_1 = F.normalize(model.forward(ref_wavs), p=2, dim=1)
+         ref_emb_2 = F.normalize(model.forward(ref_wav), p=2, dim=1)
+
+     # Average the crop-level and full-utterance cosine similarities.
+     score_1 = torch.mean(torch.matmul(input_emb_1, ref_emb_1.T))
+     score_2 = torch.mean(torch.matmul(input_emb_2, ref_emb_2.T))
+     score = (score_1 + score_2) / 2
+     return score.detach().cpu().numpy()
+
+ description = """
+ Voice similarity demo using a WavLM-ECAPA model trained on the VoxSim dataset.
+ The demo only accepts .wav files; a 16 kHz sampling rate works best.
+
+ The paper is available [here](https://arxiv.org/abs/2407.18505).
+ """
+
+ iface = gr.Interface(
+     fn=calc_voxsim,
+     inputs=[
+         # Pass file paths so loadWav can read the uploads with librosa.
+         gr.Audio(type="filepath", label="Input Audio"),
+         gr.Audio(type="filepath", label="Reference Audio"),
+     ],
+     outputs="text",
+     title="Voice Similarity with VoxSim",
+     description=description,
+     allow_flagging="never",
+ )
+
+ iface.launch()
predict.py ADDED
@@ -0,0 +1,119 @@
+ import argparse
+ import pathlib
+ import tqdm
+ from torch.utils.data import Dataset, DataLoader
+ import librosa
+ import numpy
+ from score import Score
+ import torch
+
+ import warnings
+ warnings.filterwarnings("ignore")
+
+
+ def get_arg():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--bs", required=False, default=None, type=int)
+     parser.add_argument("--mode", required=True, choices=["predict_file", "predict_dir"], type=str)
+     parser.add_argument("--ckpt_path", required=False, default="wavlm_ecapa.model", type=pathlib.Path)
+     parser.add_argument("--inp_dir", required=False, default=None, type=pathlib.Path)
+     parser.add_argument("--ref_dir", required=False, default=None, type=pathlib.Path)
+     parser.add_argument("--inp_path", required=False, default=None, type=pathlib.Path)
+     parser.add_argument("--ref_path", required=False, default=None, type=pathlib.Path)
+     parser.add_argument("--out_path", required=True, type=pathlib.Path)
+     parser.add_argument("--num_workers", required=False, default=0, type=int)
+     return parser.parse_args()
+
+
+ def loadWav(filename, max_frames: int = 400):
+
+     # Maximum audio length in samples (10 ms hops at 16 kHz plus one window)
+     max_audio = max_frames * 160 + 240
+
+     # Read wav file and convert to torch tensor
+     audio, sr = librosa.load(filename, sr=16000)
+     audio_org = audio.copy()
+
+     audiosize = audio.shape[0]
+
+     if audiosize <= max_audio:
+         shortage = max_audio - audiosize + 1
+         audio = numpy.pad(audio, (0, shortage), 'wrap')
+         audiosize = audio.shape[0]
+
+     # 10 evenly spaced crops of max_audio samples each
+     startframe = numpy.linspace(0, audiosize - max_audio, num=10)
+
+     feats = []
+     for asf in startframe:
+         feats.append(audio[int(asf):int(asf) + max_audio])
+
+     feat = numpy.stack(feats, axis=0).astype(numpy.float32)
+
+     return torch.FloatTensor(feat), torch.FloatTensor(numpy.stack([audio_org], axis=0).astype(numpy.float32))
+
+
+ class AudioDataset(Dataset):
+     def __init__(self, inp_dir_path: pathlib.Path, ref_dir_path: pathlib.Path, max_frames: int = 400):
+         self.inp_wavlist = list(inp_dir_path.glob("*.wav"))
+         self.ref_wavlist = list(ref_dir_path.glob("*.wav"))
+         assert len(self.inp_wavlist) == len(self.ref_wavlist)
+         self.inp_wavlist.sort()
+         self.ref_wavlist.sort()
+         _, self.sr = librosa.load(self.inp_wavlist[0], sr=None)
+         self.max_audio = max_frames * 160 + 240
+
+     def __len__(self):
+         return len(self.inp_wavlist)
+
+     def __getitem__(self, idx):
+         inp_wavs, inp_wav = loadWav(self.inp_wavlist[idx])
+         ref_wavs, ref_wav = loadWav(self.ref_wavlist[idx])
+         return inp_wavs, inp_wav, ref_wavs, ref_wav
+
+
+ def main():
+     args = get_arg()
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     if args.mode == "predict_file":
+         assert args.inp_path is not None
+         assert args.ref_path is not None
+         assert args.inp_dir is None
+         assert args.ref_dir is None
+         assert args.inp_path.exists()
+         assert args.inp_path.is_file()
+         assert args.ref_path.exists()
+         assert args.ref_path.is_file()
+         inp_wavs, inp_wav = loadWav(args.inp_path)
+         ref_wavs, ref_wav = loadWav(args.ref_path)
+         scorer = Score(ckpt_path=args.ckpt_path, device=device)
+         score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
+         print("Voxsim score: ", score[0])
+         with open(args.out_path, "w") as fw:
+             fw.write(str(score[0]))
+     else:
+         assert args.inp_dir is not None, "inp_dir is required when mode is predict_dir."
+         assert args.ref_dir is not None, "ref_dir is required when mode is predict_dir."
+         assert args.bs is not None, "bs is required when mode is predict_dir."
+         assert args.inp_path is None, "inp_path should be None"
+         assert args.ref_path is None, "ref_path should be None"
+         assert args.inp_dir.exists()
+         assert args.ref_dir.exists()
+         assert args.inp_dir.is_dir()
+         assert args.ref_dir.is_dir()
+         dataset = AudioDataset(args.inp_dir, args.ref_dir)
+         loader = DataLoader(
+             dataset,
+             batch_size=args.bs,
+             shuffle=False,
+             num_workers=args.num_workers)
+         scorer = Score(ckpt_path=args.ckpt_path, device=device)
+         # Truncate the output file before appending batch results.
+         with open(args.out_path, 'w'):
+             pass
+         for batch in tqdm.tqdm(loader):
+             # Each batch is a tuple of (inp_wavs, inp_wav, ref_wavs, ref_wav) tensors;
+             # Score.score moves them to the right device itself.
+             inp_wavs, inp_wav, ref_wavs, ref_wav = batch
+             scores = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
+             with open(args.out_path, 'a') as fw:
+                 for s in scores:
+                     fw.write(str(s) + "\n")
+         print("Saved to ", args.out_path)
+
+
+ if __name__ == "__main__":
+     main()
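
predict.py exposes two modes through argparse: `predict_file` scores a single pair of files, and `predict_dir` scores two directories of matched .wav files in batches. The shell equivalent of the single-file mode would be roughly `python predict.py --mode predict_file --inp_path input.wav --ref_path reference.wav --out_path voxsim_score.txt`. A minimal hedged sketch of driving the same mode from Python, assuming two hypothetical local files `input.wav` and `reference.wav` exist and the default checkpoint can be fetched:

```python
# Sketch only: set sys.argv and call predict.main(); file names are hypothetical placeholders.
import sys
import predict

sys.argv = [
    "predict.py",
    "--mode", "predict_file",
    "--inp_path", "input.wav",         # hypothetical input utterance
    "--ref_path", "reference.wav",     # hypothetical reference utterance
    "--out_path", "voxsim_score.txt",  # the score is written here as plain text
]
predict.main()
```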
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ numpy
+ librosa
+ torch
+ torchaudio
+ tqdm
+ s3prl
+ huggingface_hub
score.py ADDED
@@ -0,0 +1,102 @@
+ import os
+ import torch
+ import torch.nn.functional as F
+ from ssl_ecapa_model import SSL_ECAPA_TDNN
+ from huggingface_hub import hf_hub_download
+
+
+ def load_model(ckpt_path):
+     model = SSL_ECAPA_TDNN(feat_dim=1024, emb_dim=256, feat_type='wavlm_large')
+     load_parameters(model, ckpt_path)
+     return model
+
+
+ def load_parameters(model, ckpt_path):
+     model_state = model.state_dict()
+     if not os.path.isfile(ckpt_path):
+         print("Downloading model from Hugging Face Hub...")
+         new_ckpt_path = hf_hub_download(repo_id="junseok520/voxsim-models", filename=str(ckpt_path), local_dir="./")
+         ckpt_path = new_ckpt_path
+     loaded_state = torch.load(ckpt_path, map_location='cpu', weights_only=True)
+
+     # Copy every '__S__.'-prefixed tensor from the checkpoint into the model.
+     for name, param in loaded_state.items():
+         if name.startswith('__S__.'):
+             if name[6:] in model_state:
+                 model_state[name[6:]].copy_(param)
+             else:
+                 print("{} is not in the model.".format(name[6:]))
+
+
+ class Score:
+     """Predicts a voice-similarity score for each pair of audio clips."""
+
+     def __init__(
+             self,
+             ckpt_path: str = "wavlm_ecapa.model",
+             device: str = "cuda"):
+         """
+         Args:
+             ckpt_path: path to a pretrained checkpoint of the VoxSim evaluator.
+             device: device the model runs on, e.g. "cpu" or "cuda".
+         """
+         print(f"Using device: {device}")
+         self.device = device
+         self.model = load_model(ckpt_path).to(self.device)
+         self.model.eval()
+
+     def score(self, inp_wavs: torch.Tensor, inp_wav: torch.Tensor, ref_wavs: torch.Tensor, ref_wav: torch.Tensor):
+         """
+         Args:
+             inp_wavs, ref_wavs: stacks of fixed-length crops, shape (10, T) for a single
+                 pair or (B, 10, T) for a batch.
+             inp_wav, ref_wav: full-length waveforms, shape (1, T') or (B, 1, T').
+         Returns:
+             A numpy array with one similarity score per pair.
+         """
+         if len(inp_wavs.shape) == 2:
+             bs = 1
+         elif len(inp_wavs.shape) == 3:
+             bs = inp_wavs.shape[0]
+         else:
+             raise ValueError('Dimension of input tensor needs to be <= 3.')
+
+         inp_wavs = inp_wavs.reshape(-1, inp_wavs.shape[-1]).to(self.device)
+         inp_wav = inp_wav.reshape(-1, inp_wav.shape[-1]).to(self.device)
+         ref_wavs = ref_wavs.reshape(-1, ref_wavs.shape[-1]).to(self.device)
+         ref_wav = ref_wav.reshape(-1, ref_wav.shape[-1]).to(self.device)
+
+         with torch.no_grad():
+             input_emb_1 = F.normalize(self.model.forward(inp_wavs), p=2, dim=1).detach()
+             input_emb_2 = F.normalize(self.model.forward(inp_wav), p=2, dim=1).detach()
+             ref_emb_1 = F.normalize(self.model.forward(ref_wavs), p=2, dim=1).detach()
+             ref_emb_2 = F.normalize(self.model.forward(ref_wav), p=2, dim=1).detach()
+
+         emb_size = input_emb_1.shape[-1]
+         input_emb_1 = input_emb_1.reshape(bs, -1, emb_size)
+         input_emb_2 = input_emb_2.reshape(bs, -1, emb_size)
+         ref_emb_1 = ref_emb_1.reshape(bs, -1, emb_size)
+         ref_emb_2 = ref_emb_2.reshape(bs, -1, emb_size)
+
+         # Mean pairwise cosine similarity between crop embeddings, averaged with the
+         # full-utterance similarity.
+         score_1 = torch.mean(torch.bmm(input_emb_1, ref_emb_1.transpose(1, 2)), dim=(1, 2))
+         score_2 = torch.mean(torch.bmm(input_emb_2, ref_emb_2.transpose(1, 2)), dim=(1, 2))
+         score = (score_1 + score_2) / 2
+         score = score.detach().cpu().numpy()
+
+         return score
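
As a rough illustration of the shapes Score expects (not part of the commit): loadWav returns a (10, T) stack of fixed-length crops plus a (1, T') full-length waveform per file, and Score.score consumes one such quadruple, or 3-D batched versions of it. A minimal hedged sketch, assuming the `wavlm_ecapa.model` checkpoint can be fetched from the Hub and that `a.wav` and `b.wav` are hypothetical 16 kHz mono files:

```python
# Sketch only: score one pair of utterances directly with the Score wrapper.
import torch
from predict import loadWav
from score import Score

device = "cuda" if torch.cuda.is_available() else "cpu"
scorer = Score(ckpt_path="wavlm_ecapa.model", device=device)

inp_wavs, inp_wav = loadWav("a.wav")  # (10, T) crops and (1, T') full waveform
ref_wavs, ref_wav = loadWav("b.wav")

score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
print(score)  # numpy array with one similarity value
```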
ssl_ecapa_model.py ADDED
@@ -0,0 +1,314 @@
+ # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchaudio.transforms as trans
+
+ urls = {
+     'hubert_large_ll60k': "https://dl.fbaipublicfiles.com/hubert/hubert_large_ll60k.pt",
+     'xls_r_300m': "https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr2_300m.pt",
+     'unispeech_sat': "https://huggingface.co/s3prl/converted_ckpts/resolve/main/unispeech_sat_large.pt",
+     'wavlm_base_plus': "https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_base_plus.pt",
+     'wavlm_large': "https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt",
+ }
+
+
+ ''' Res2Conv1d + BatchNorm1d + ReLU
+ '''
+
+
+ class Res2Conv1dReluBn(nn.Module):
+     '''
+     in_channels == out_channels == channels
+     '''
+
+     def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
+         super().__init__()
+         assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
+         self.scale = scale
+         self.width = channels // scale
+         self.nums = scale if scale == 1 else scale - 1
+
+         self.convs = []
+         self.bns = []
+         for i in range(self.nums):
+             self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
+             self.bns.append(nn.BatchNorm1d(self.width))
+         self.convs = nn.ModuleList(self.convs)
+         self.bns = nn.ModuleList(self.bns)
+
+     def forward(self, x):
+         out = []
+         spx = torch.split(x, self.width, 1)
+         for i in range(self.nums):
+             if i == 0:
+                 sp = spx[i]
+             else:
+                 sp = sp + spx[i]
+             # Order: conv -> relu -> bn
+             sp = self.convs[i](sp)
+             sp = self.bns[i](F.relu(sp))
+             out.append(sp)
+         if self.scale != 1:
+             out.append(spx[self.nums])
+         out = torch.cat(out, dim=1)
+
+         return out
+
+
+ ''' Conv1d + BatchNorm1d + ReLU
+ '''
+
+
+ class Conv1dReluBn(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
+         super().__init__()
+         self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
+         self.bn = nn.BatchNorm1d(out_channels)
+
+     def forward(self, x):
+         return self.bn(F.relu(self.conv(x)))
+
+
+ ''' The SE connection for the 1D case.
+ '''
+
+
+ class SE_Connect(nn.Module):
+     def __init__(self, channels, se_bottleneck_dim=128):
+         super().__init__()
+         self.linear1 = nn.Linear(channels, se_bottleneck_dim)
+         self.linear2 = nn.Linear(se_bottleneck_dim, channels)
+
+     def forward(self, x):
+         out = x.mean(dim=2)
+         out = F.relu(self.linear1(out))
+         out = torch.sigmoid(self.linear2(out))
+         out = x * out.unsqueeze(2)
+
+         return out
+
+
+ ''' SE-Res2Block of the ECAPA-TDNN architecture.
+ '''
+
+
+ # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
+ #     return nn.Sequential(
+ #         Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
+ #         Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
+ #         Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
+ #         SE_Connect(channels)
+ #     )
+
+
+ class SE_Res2Block(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
+         super().__init__()
+         self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+         self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
+         self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
+         self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
+
+         self.shortcut = None
+         if in_channels != out_channels:
+             self.shortcut = nn.Conv1d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=1,
+             )
+
+     def forward(self, x):
+         residual = x
+         if self.shortcut:
+             residual = self.shortcut(x)
+
+         x = self.Conv1dReluBn1(x)
+         x = self.Res2Conv1dReluBn(x)
+         x = self.Conv1dReluBn2(x)
+         x = self.SE_Connect(x)
+
+         return x + residual
+
+
+ ''' Attentive weighted mean and standard deviation pooling.
+ '''
+
+
+ class AttentiveStatsPool(nn.Module):
+     def __init__(self, in_dim, attention_channels=128, global_context_att=False):
+         super().__init__()
+         self.global_context_att = global_context_att
+
+         # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
+         if global_context_att:
+             self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1)  # equals W and b in the paper
+         else:
+             self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1)  # equals W and b in the paper
+         self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1)  # equals V and k in the paper
+
+     def forward(self, x):
+
+         if self.global_context_att:
+             context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
+             context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
+             x_in = torch.cat((x, context_mean, context_std), dim=1)
+         else:
+             x_in = x
+
+         # DON'T use ReLU here! In experiments, ReLU was hard to converge.
+         alpha = torch.tanh(self.linear1(x_in))
+         # alpha = F.relu(self.linear1(x_in))
+         alpha = torch.softmax(self.linear2(alpha), dim=2)
+         mean = torch.sum(alpha * x, dim=2)
+         residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
+         std = torch.sqrt(residuals.clamp(min=1e-9))
+         return torch.cat([mean, std], dim=1)
+
+
+ class SSL_ECAPA_TDNN(nn.Module):
+     def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
+                  feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, initial_model="", **kwargs):
+         super().__init__()
+
+         self.feat_type = feat_type
+         self.feature_selection = feature_selection
+         self.update_extract = update_extract
+         self.sr = sr
+
+         if feat_type == "fbank" or feat_type == "mfcc":
+             self.update_extract = False
+
+         win_len = int(sr * 0.025)
+         hop_len = int(sr * 0.01)
+
+         if feat_type == 'fbank':
+             self.feature_extract = trans.MelSpectrogram(sample_rate=sr, n_fft=512, win_length=win_len,
+                                                         hop_length=hop_len, f_min=0.0, f_max=sr // 2,
+                                                         pad=0, n_mels=feat_dim)
+         elif feat_type == 'mfcc':
+             melkwargs = {
+                 'n_fft': 512,
+                 'win_length': win_len,
+                 'hop_length': hop_len,
+                 'f_min': 0.0,
+                 'f_max': sr // 2,
+                 'pad': 0
+             }
+             self.feature_extract = trans.MFCC(sample_rate=sr, n_mfcc=feat_dim, log_mels=False,
+                                               melkwargs=melkwargs)
+         else:
+             # Self-supervised front-end (e.g. wavlm_large) loaded through s3prl.
+             self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
+
+             if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
+                 self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
+             if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
+                 self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
+
+         self.feat_num = self.get_feat_num()
+         self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
+         # self.feature_weight = nn.Parameter(torch.zeros(7))
+
+         if feat_type != 'fbank' and feat_type != 'mfcc':
+             freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
+             for name, param in self.feature_extract.named_parameters():
+                 for freeze_val in freeze_list:
+                     if freeze_val in name:
+                         param.requires_grad = False
+                         break
+
+         if not self.update_extract:
+             for param in self.feature_extract.parameters():
+                 param.requires_grad = False
+
+         self.instance_norm = nn.InstanceNorm1d(feat_dim)
+         # self.channels = [channels] * 4 + [channels * 3]
+         self.channels = [channels] * 4 + [1536]
+
+         self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
+         self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
+         self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
+         self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
+
+         # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
+         cat_channels = channels * 3
+         self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
+         self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
+         self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
+         self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
+
+     def get_feat_num(self):
+         # Run one dummy waveform through the front-end to count the hidden-state layers.
+         self.feature_extract.eval()
+         wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
+         with torch.no_grad():
+             features = self.feature_extract(wav)
+         select_feature = features[self.feature_selection]
+         if isinstance(select_feature, (list, tuple)):
+             return len(select_feature)
+         else:
+             return 1
+
+     def get_feat(self, x):
+         if self.update_extract:
+             x = self.feature_extract([sample for sample in x])
+         else:
+             with torch.no_grad():
+                 if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
+                     x = self.feature_extract(x) + 1e-6  # B x feat_dim x time_len
+                 else:
+                     x = self.feature_extract([sample for sample in x])
+
+         if self.feat_type == 'fbank':
+             x = x.log()
+
+         if self.feat_type != "fbank" and self.feat_type != "mfcc":
+             x = x[self.feature_selection]
+             # x = x[1:8]
+             # x = x[2]
+             if isinstance(x, (list, tuple)):
+                 x = torch.stack(x, dim=0)
+             else:
+                 x = x.unsqueeze(0)
+             # Learned softmax weighting over the SSL hidden layers.
+             norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+             # norm_weights = F.softmax(self.feature_weight[1:8], dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+             x = (norm_weights * x).sum(dim=0)
+             x = torch.transpose(x, 1, 2) + 1e-6
+
+         x = self.instance_norm(x)
+         return x
+
+     def forward(self, x):
+         x = self.get_feat(x)
+
+         out1 = self.layer1(x)
+         out2 = self.layer2(out1)
+         out3 = self.layer3(out2)
+         out4 = self.layer4(out3)
+
+         out = torch.cat([out2, out3, out4], dim=1)
+         out = F.relu(self.conv(out))
+         out = self.bn(self.pooling(out))
+         out = self.linear(out)
+
+         return out
+
+
+ def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False):
+     return SSL_ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
+                           feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract)
+
+
+ def wavlm_ecapa():
+     return SSL_ECAPA_TDNN(feat_dim=1024, emb_dim=256, feat_type='wavlm_large')
+
+
+ if __name__ == '__main__':
+     x = torch.zeros(2, 32000)
+     model = SSL_ECAPA_TDNN(feat_dim=1024, emb_dim=256, feat_type='wavlm_large',
+                            feature_selection="hidden_states", update_extract=False)
+     out = model(x)
+     # print(model)
+     print(out.shape)