ollieollie commited on
Commit
969cb52
·
1 Parent(s): 99cc645

change mel func

Browse files
orator/src/orator/models/voice_encoder/melspec.py CHANGED
@@ -1,75 +1,78 @@
1
  from functools import lru_cache
2
 
 
3
  import numpy as np
4
- import torch
5
- from torchaudio.transforms import MelSpectrogram
6
-
7
- from .config import VoiceEncConfig
8
-
9
-
10
- class ResembleMelSpectrogram(torch.nn.Module):
11
- def __init__(self, hp=VoiceEncConfig()):
12
- """
13
- Torch implementation of Resemble's mel extraction.
14
- Note that the values are NOT identical to librosa's implementation due to floating point precisions, however
15
- the results are very very close. One test file gave an L1 error of just 0.005%, full results:
16
- Librosa mel max: 0.871768
17
- Torch mel max: 0.871768
18
- Librosa mel mean: 0.316302
19
- Torch mel mean: 0.316289
20
- Max diff: 0.061105
21
- Mean diff: 1.453384e-05
22
- Percent error: 0.004595%
23
- """
24
- super().__init__()
25
- self.melspec = MelSpectrogram(
26
- hp.sample_rate,
27
- n_fft=hp.n_fft,
28
- win_length=hp.win_size,
29
- hop_length=hp.hop_size,
30
- f_min=hp.fmin,
31
- f_max=hp.fmax,
32
- n_mels=hp.num_mels,
33
- power=1,
34
- normalized=False,
35
- # NOTE: Folowing librosa's default.
36
- pad_mode="constant",
37
- norm="slaney",
38
- mel_scale="slaney",
39
- )
40
- self.register_buffer(
41
- "stft_magnitude_min",
42
- torch.FloatTensor([hp.stft_magnitude_min])
43
- )
44
- self.min_level_db = 20 * np.log10(hp.stft_magnitude_min)
45
- self.preemphasis = hp.preemphasis
46
- self.hop_size = hp.hop_size
47
-
48
- def forward(self, wav, pad=True):
49
- """
50
- Args:
51
- wav: [B, T]
52
- """
53
- if self.preemphasis > 0:
54
- wav = torch.nn.functional.pad(wav, [1, 0], value=0)
55
- wav = wav[..., 1:] - self.preemphasis * wav[..., :-1]
56
-
57
- mel = self.melspec(wav)
58
-
59
- mel = self._amp_to_db(mel)
60
- mel_normed = self._normalize(mel)
61
- assert not pad or mel_normed.shape[-1] == 1 + \
62
- wav.shape[-1] // self.hop_size # Sanity check
63
- return mel_normed # (M, T)
64
-
65
- def _normalize(self, s, headroom_db=15):
66
- s = (s - self.min_level_db) / (-self.min_level_db + headroom_db)
67
- return s
68
-
69
- def _amp_to_db(self, x):
70
- return 20 * torch.maximum(self.stft_magnitude_min, x).log10()
71
 
72
 
73
  @lru_cache()
74
- def melspectrogram():
75
- return ResembleMelSpectrogram()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from functools import lru_cache
2
 
3
+ from scipy import signal
4
  import numpy as np
5
+ import librosa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  @lru_cache()
9
+ def mel_basis(hp):
10
+ assert hp.fmax <= hp.sample_rate // 2
11
+ return librosa.filters.mel(
12
+ sr=hp.sample_rate,
13
+ n_fft=hp.n_fft,
14
+ n_mels=hp.num_mels,
15
+ fmin=hp.fmin,
16
+ fmax=hp.fmax) # -> (nmel, nfreq)
17
+
18
+
19
+ def preemphasis(wav, hp):
20
+ assert hp.preemphasis != 0
21
+ wav = signal.lfilter([1, -hp.preemphasis], [1], wav)
22
+ wav = np.clip(wav, -1, 1)
23
+ return wav
24
+
25
+
26
+ def melspectrogram(wav, hp, pad=True):
27
+ # Run through pre-emphasis
28
+ if hp.preemphasis > 0:
29
+ wav = preemphasis(wav, hp)
30
+ assert np.abs(wav).max() - 1 < 1e-07
31
+
32
+ # Do the stft
33
+ spec_complex = _stft(wav, hp, pad=pad)
34
+
35
+ # Get the magnitudes
36
+ spec_magnitudes = np.abs(spec_complex)
37
+
38
+ if hp.mel_power != 1.0:
39
+ spec_magnitudes **= hp.mel_power
40
+
41
+ # Get the mel and convert magnitudes->db
42
+ mel = np.dot(mel_basis(hp), spec_magnitudes)
43
+ if hp.mel_type == "db":
44
+ mel = _amp_to_db(mel, hp)
45
+
46
+ # Normalise the mel from db to 0,1
47
+ if hp.normalized_mels:
48
+ mel = _normalize(mel, hp).astype(np.float32)
49
+
50
+ assert not pad or mel.shape[1] == 1 + len(wav) // hp.hop_size # Sanity check
51
+ return mel # (M, T)
52
+
53
+
54
+ def _stft(y, hp, pad=True):
55
+ # NOTE: after 0.8, pad mode defaults to constant, setting this to reflect for
56
+ # historical consistency and streaming-version consistency
57
+ return librosa.stft(
58
+ y,
59
+ n_fft=hp.n_fft,
60
+ hop_length=hp.hop_size,
61
+ win_length=hp.win_size,
62
+ center=pad,
63
+ pad_mode="reflect",
64
+ )
65
+
66
+
67
+ def _amp_to_db(x, hp):
68
+ return 20 * np.log10(np.maximum(hp.stft_magnitude_min, x))
69
+
70
+
71
+ def _db_to_amp(x):
72
+ return np.power(10.0, x * 0.05)
73
+
74
+
75
+ def _normalize(s, hp, headroom_db=15):
76
+ min_level_db = 20 * np.log10(hp.stft_magnitude_min)
77
+ s = (s - min_level_db) / (-min_level_db + headroom_db)
78
+ return s
orator/src/orator/models/voice_encoder/voice_encoder.py CHANGED
@@ -269,8 +269,6 @@ class VoiceEncoder(nn.Module):
269
  if "rate" not in kwargs:
270
  kwargs["rate"] = 1.3 # Resemble's default value.
271
 
272
- mel_func = melspectrogram()
273
- mels = [mel_func(torch.from_numpy(w)
274
- [None])[0].T for w in wavs]
275
 
276
  return self.embeds_from_mels(mels, as_spk=as_spk, batch_size=batch_size, **kwargs)
 
269
  if "rate" not in kwargs:
270
  kwargs["rate"] = 1.3 # Resemble's default value.
271
 
272
+ mels = [melspectrogram(w, self.hp).T for w in wavs]
 
 
273
 
274
  return self.embeds_from_mels(mels, as_spk=as_spk, batch_size=batch_size, **kwargs)