Commit: Upload 100 files

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- README.md +6 -5
- app.py +73 -0
- common/audio_processing.py +120 -0
- common/env.py +25 -0
- common/filter_warnings.py +33 -0
- common/gpu_affinity.py +156 -0
- common/layers.py +134 -0
- common/repeated_dataloader.py +59 -0
- common/stft.py +140 -0
- common/tb_dllogger.py +172 -0
- common/text/LICENSE +19 -0
- common/text/__init__.py +3 -0
- common/text/__pycache__/__init__.cpython-37.pyc +0 -0
- common/text/__pycache__/__init__.cpython-38.pyc +0 -0
- common/text/__pycache__/__init__.cpython-39.pyc +0 -0
- common/text/__pycache__/abbreviations.cpython-37.pyc +0 -0
- common/text/__pycache__/abbreviations.cpython-38.pyc +0 -0
- common/text/__pycache__/abbreviations.cpython-39.pyc +0 -0
- common/text/__pycache__/acronyms.cpython-37.pyc +0 -0
- common/text/__pycache__/acronyms.cpython-38.pyc +0 -0
- common/text/__pycache__/acronyms.cpython-39.pyc +0 -0
- common/text/__pycache__/cleaners.cpython-37.pyc +0 -0
- common/text/__pycache__/cleaners.cpython-38.pyc +0 -0
- common/text/__pycache__/cleaners.cpython-39.pyc +0 -0
- common/text/__pycache__/cmudict.cpython-37.pyc +0 -0
- common/text/__pycache__/cmudict.cpython-38.pyc +0 -0
- common/text/__pycache__/cmudict.cpython-39.pyc +0 -0
- common/text/__pycache__/datestime.cpython-37.pyc +0 -0
- common/text/__pycache__/datestime.cpython-38.pyc +0 -0
- common/text/__pycache__/datestime.cpython-39.pyc +0 -0
- common/text/__pycache__/letters_and_numbers.cpython-37.pyc +0 -0
- common/text/__pycache__/letters_and_numbers.cpython-38.pyc +0 -0
- common/text/__pycache__/letters_and_numbers.cpython-39.pyc +0 -0
- common/text/__pycache__/numerical.cpython-37.pyc +0 -0
- common/text/__pycache__/numerical.cpython-38.pyc +0 -0
- common/text/__pycache__/numerical.cpython-39.pyc +0 -0
- common/text/__pycache__/symbols.cpython-37.pyc +0 -0
- common/text/__pycache__/symbols.cpython-38.pyc +0 -0
- common/text/__pycache__/symbols.cpython-39.pyc +0 -0
- common/text/__pycache__/text_processing.cpython-37.pyc +0 -0
- common/text/__pycache__/text_processing.cpython-38.pyc +0 -0
- common/text/__pycache__/text_processing.cpython-39.pyc +0 -0
- common/text/abbreviations.py +67 -0
- common/text/acronyms.py +109 -0
- common/text/cleaners.py +102 -0
- common/text/cmudict.py +98 -0
- common/text/datestime.py +22 -0
- common/text/letters_and_numbers.py +90 -0
- common/text/numerical.py +153 -0
- common/text/symbols.py +81 -0
README.md
CHANGED
@@ -1,14 +1,15 @@
 ---
-title:
+title: Multi Sami
-emoji:
+emoji: 🔥
 colorFrom: green
-colorTo:
+colorTo: pink
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.15.0
 app_file: app.py
 pinned: false
 license: cc-by-nc-nd-4.0
-
+#license: cc-by-4.0
+short_description: Multilingual, multi-speaker Sámi TTS
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,73 @@
+import gradio as gr
+import syn_hifigan as syn
+#import syn_k_univnet_multi as syn
+import os, tempfile
+
+languages = {"South Sámi": 0,
+             "North Sámi": 1,
+             "Lule Sámi": 2}
+
+speakers = {  # "aj0": 0,
+            "Aanna - sma": 1,
+            "Máhtte": 2,
+            "Siggá - smj": 3,
+            "Biret - sme": 5,
+            # "lo": 6,
+            "Sunná": 7,
+            "Abmut - smj": 8,
+            "Nihkol - smj": 9
+            }
+public = True
+
+tempdir = tempfile.gettempdir()
+
+tts = syn.Synthesizer()
+
+
+def speak(text, language, speaker, l_weight, s_weight, pace, postfilter):  # pitch_shift, pitch_std
+    # text frontend not implemented...
+    text = text.replace("...", "…")
+    print(speakers[speaker])
+    audio = tts.speak(text, output_file=f'{tempdir}/tmp', lang=languages[language],
+                      spkr=speakers[speaker], l_weight=l_weight, s_weight=s_weight,
+                      pace=pace, clarity=postfilter)
+
+    if not public:
+        try:
+            os.system("play " + tempdir + "/tmp.wav &")
+        except:
+            pass
+
+    return (22050, audio)
+
+
+controls = []
+controls.append(gr.Textbox(label="text", value="Suohtas duinna deaivvadit."))
+controls.append(gr.Dropdown(list(languages.keys()), label="language", value="North Sámi"))
+controls.append(gr.Dropdown(list(speakers.keys()), label="speaker", value="Sunná"))
+
+controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1, label="language weight"))
+controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1, label="speaker weight"))
+
+controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1.0, label="speech rate"))
+controls.append(gr.Slider(minimum=0., maximum=2, step=0.05, value=1.0, label="post-processing"))
+
+
+tts_gui = gr.Interface(
+    fn=speak,
+    inputs=controls,
+    outputs=gr.Audio(label="output"),
+    live=False
+)
+
+
+if __name__ == "__main__":
+    tts_gui.launch(share=public)
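The speak callback returns a (sample_rate, numpy_array) tuple, which Gradio's Audio component plays directly from memory. A minimal, self-contained sketch of that contract (a synthetic sine tone stands in for the Space's Synthesizer, which needs the model checkpoints to run):

    import gradio as gr
    import numpy as np

    def beep(duration_s):
        # gr.Audio accepts a (sample_rate, waveform) tuple; no temp file needed
        sr = 22050
        t = np.linspace(0, duration_s, int(sr * duration_s), endpoint=False)
        return (sr, 0.3 * np.sin(2 * np.pi * 440.0 * t).astype(np.float32))

    demo = gr.Interface(fn=beep,
                        inputs=gr.Slider(minimum=0.1, maximum=2.0, value=0.5,
                                         label="seconds"),
                        outputs=gr.Audio(label="output"))

    if __name__ == "__main__":
        demo.launch()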
common/audio_processing.py
ADDED
@@ -0,0 +1,120 @@
+# *****************************************************************************
+#  Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+#  Redistribution and use in source and binary forms, with or without
+#  modification, are permitted provided that the following conditions are met:
+#      * Redistributions of source code must retain the above copyright
+#        notice, this list of conditions and the following disclaimer.
+#      * Redistributions in binary form must reproduce the above copyright
+#        notice, this list of conditions and the following disclaimer in the
+#        documentation and/or other materials provided with the distribution.
+#      * Neither the name of the NVIDIA CORPORATION nor the
+#        names of its contributors may be used to endorse or promote products
+#        derived from this software without specific prior written permission.
+#
+#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+#  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+#  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+#  DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+#  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+#  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+#  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+#  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+# *****************************************************************************
+
+import librosa.util as librosa_util
+import numpy as np
+import torch
+from scipy.signal import get_window
+
+
+def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
+                     n_fft=800, dtype=np.float32, norm=None):
+    """
+    # from librosa 0.6
+    Compute the sum-square envelope of a window function at a given hop length.
+
+    This is used to estimate modulation effects induced by windowing
+    observations in short-time fourier transforms.
+
+    Parameters
+    ----------
+    window : string, tuple, number, callable, or list-like
+        Window specification, as in `get_window`
+
+    n_frames : int > 0
+        The number of analysis frames
+
+    hop_length : int > 0
+        The number of samples to advance between frames
+
+    win_length : [optional]
+        The length of the window function.  By default, this matches `n_fft`.
+
+    n_fft : int > 0
+        The length of each analysis frame.
+
+    dtype : np.dtype
+        The data type of the output
+
+    Returns
+    -------
+    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
+        The sum-squared envelope of the window function
+    """
+    if win_length is None:
+        win_length = n_fft
+
+    n = n_fft + hop_length * (n_frames - 1)
+    x = np.zeros(n, dtype=dtype)
+
+    # Compute the squared window at the desired length
+    win_sq = get_window(window, win_length, fftbins=True)
+    win_sq = librosa_util.normalize(win_sq, norm=norm)**2
+    win_sq = librosa_util.pad_center(win_sq, size=n_fft)
+
+    # Fill the envelope
+    for i in range(n_frames):
+        sample = i * hop_length
+        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
+    return x
+
+
+def griffin_lim(magnitudes, stft_fn, n_iters=30):
+    """
+    PARAMS
+    ------
+    magnitudes: spectrogram magnitudes
+    stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
+    """
+    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
+    angles = angles.astype(np.float32)
+    angles = torch.autograd.Variable(torch.from_numpy(angles))
+    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+
+    for i in range(n_iters):
+        _, angles = stft_fn.transform(signal)
+        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+    return signal
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+    """
+    PARAMS
+    ------
+    C: compression factor
+    """
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+    """
+    PARAMS
+    ------
+    C: compression factor used to compress
+    """
+    return torch.exp(x) / C
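dynamic_range_compression is a clipped log and dynamic_range_decompression is its exact inverse for values above clip_val; a quick sanity check, assuming this repo's layout is on PYTHONPATH:

    import torch
    from common.audio_processing import (dynamic_range_compression,
                                         dynamic_range_decompression)

    mag = torch.rand(4, 80, 100) + 1e-3   # fake magnitudes, all above clip_val
    roundtrip = dynamic_range_decompression(dynamic_range_compression(mag))
    print(torch.allclose(roundtrip, mag, atol=1e-6))   # True: exp(log(x)) == x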
common/env.py
ADDED
@@ -0,0 +1,25 @@
+import os
+import shutil
+from collections import defaultdict
+
+
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+
+
+class DefaultAttrDict(defaultdict):
+    def __init__(self, *args, **kwargs):
+        super(DefaultAttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+
+    def __getattr__(self, item):
+        return self[item]
+
+
+def build_env(config, config_name, path):
+    t_path = os.path.join(path, config_name)
+    if config != t_path:
+        os.makedirs(path, exist_ok=True)
+        shutil.copyfile(config, os.path.join(path, config_name))
common/filter_warnings.py
ADDED
@@ -0,0 +1,33 @@
+# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Mutes known and unrelated PyTorch warnings.
+
+The warnings module keeps a list of filters. Importing it as late as possible
+prevents its filters from being overridden.
+"""
+
+import warnings
+
+
+# NGC 22.04-py3 container (PyTorch 1.12.0a0+bd13bc6)
+warnings.filterwarnings(
+    "ignore",
+    message='positional arguments and argument "destination" are deprecated.'
+            ' nn.Module.state_dict will not accept them in the future.')
+
+# 22.08-py3 container
+warnings.filterwarnings(
+    "ignore",
+    message="is_namedtuple is deprecated, please use the python checks")
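The module is imported purely for its side effect, so a training or inference script only needs one line near the top, after which the filters apply process-wide:

    import common.filter_warnings  # noqa: F401  (registers the filters on import)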
common/gpu_affinity.py
ADDED
@@ -0,0 +1,156 @@
+# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import math
+import os
+import pathlib
+import re
+
+import pynvml
+
+pynvml.nvmlInit()
+
+
+def systemGetDriverVersion():
+    return pynvml.nvmlSystemGetDriverVersion()
+
+
+def deviceGetCount():
+    return pynvml.nvmlDeviceGetCount()
+
+
+class device:
+    # assume nvml returns list of 64 bit ints
+    _nvml_affinity_elements = math.ceil(os.cpu_count() / 64)
+
+    def __init__(self, device_idx):
+        super().__init__()
+        self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
+
+    def getName(self):
+        return pynvml.nvmlDeviceGetName(self.handle)
+
+    def getCpuAffinity(self):
+        affinity_string = ''
+        for j in pynvml.nvmlDeviceGetCpuAffinity(
+                self.handle, device._nvml_affinity_elements):
+            # assume nvml returns list of 64 bit ints
+            affinity_string = '{:064b}'.format(j) + affinity_string
+        affinity_list = [int(x) for x in affinity_string]
+        affinity_list.reverse()  # so core 0 is in 0th element of list
+
+        ret = [i for i, e in enumerate(affinity_list) if e != 0]
+        return ret
+
+
+def set_socket_affinity(gpu_id):
+    dev = device(gpu_id)
+    affinity = dev.getCpuAffinity()
+    os.sched_setaffinity(0, affinity)
+
+
+def set_single_affinity(gpu_id):
+    dev = device(gpu_id)
+    affinity = dev.getCpuAffinity()
+    os.sched_setaffinity(0, affinity[:1])
+
+
+def set_single_unique_affinity(gpu_id, nproc_per_node):
+    devices = [device(i) for i in range(nproc_per_node)]
+    socket_affinities = [dev.getCpuAffinity() for dev in devices]
+
+    siblings_list = get_thread_siblings_list()
+    siblings_dict = dict(siblings_list)
+
+    # remove siblings
+    for idx, socket_affinity in enumerate(socket_affinities):
+        socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
+
+    affinities = []
+    assigned = []
+
+    for socket_affinity in socket_affinities:
+        for core in socket_affinity:
+            if core not in assigned:
+                affinities.append([core])
+                assigned.append(core)
+                break
+    os.sched_setaffinity(0, affinities[gpu_id])
+
+
+def set_socket_unique_affinity(gpu_id, nproc_per_node, mode):
+    device_ids = [device(i) for i in range(nproc_per_node)]
+    socket_affinities = [dev.getCpuAffinity() for dev in device_ids]
+
+    siblings_list = get_thread_siblings_list()
+    siblings_dict = dict(siblings_list)
+
+    # remove siblings
+    for idx, socket_affinity in enumerate(socket_affinities):
+        socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
+
+    socket_affinities_to_device_ids = collections.defaultdict(list)
+
+    for idx, socket_affinity in enumerate(socket_affinities):
+        socket_affinities_to_device_ids[tuple(socket_affinity)].append(idx)
+
+    for socket_affinity, device_ids in socket_affinities_to_device_ids.items():
+        devices_per_group = len(device_ids)
+        cores_per_device = len(socket_affinity) // devices_per_group
+        for group_id, device_id in enumerate(device_ids):
+            if device_id == gpu_id:
+                if mode == 'interleaved':
+                    affinity = list(socket_affinity[group_id::devices_per_group])
+                elif mode == 'continuous':
+                    affinity = list(socket_affinity[group_id*cores_per_device:(group_id+1)*cores_per_device])
+                else:
+                    raise RuntimeError('Unknown set_socket_unique_affinity mode')
+
+                # reintroduce siblings
+                affinity += [siblings_dict[aff] for aff in affinity if aff in siblings_dict]
+                os.sched_setaffinity(0, affinity)
+
+
+def get_thread_siblings_list():
+    path = '/sys/devices/system/cpu/cpu*/topology/thread_siblings_list'
+    thread_siblings_list = []
+    pattern = re.compile(r'(\d+)\D(\d+)')
+    for fname in pathlib.Path(path[0]).glob(path[1:]):
+        with open(fname) as f:
+            content = f.read().strip()
+            res = pattern.findall(content)
+            if res:
+                pair = tuple(map(int, res[0]))
+                thread_siblings_list.append(pair)
+    return thread_siblings_list
+
+
+def set_affinity(gpu_id, nproc_per_node, mode='socket'):
+    if mode == 'socket':
+        set_socket_affinity(gpu_id)
+    elif mode == 'single':
+        set_single_affinity(gpu_id)
+    elif mode == 'single_unique':
+        set_single_unique_affinity(gpu_id, nproc_per_node)
+    elif mode == 'socket_unique_interleaved':
+        set_socket_unique_affinity(gpu_id, nproc_per_node, 'interleaved')
+    elif mode == 'socket_unique_continuous':
+        set_socket_unique_affinity(gpu_id, nproc_per_node, 'continuous')
+    else:
+        raise RuntimeError('Unknown affinity mode')
+
+    affinity = os.sched_getaffinity(0)
+    return affinity
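set_affinity is meant to be called once per rank before heavy CPU work starts, so each process is pinned to cores local to its GPU. A hedged sketch for a torchrun-style launch on a Linux box with pynvml and NVIDIA GPUs (the environment variable names come from the standard PyTorch launcher, not from this repo):

    import os
    from common.gpu_affinity import set_affinity

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    nproc_per_node = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
    affinity = set_affinity(local_rank, nproc_per_node,
                            mode='socket_unique_continuous')
    print(f"rank {local_rank}: pinned to {len(affinity)} CPUs")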
common/layers.py
ADDED
@@ -0,0 +1,134 @@
+# *****************************************************************************
+#  Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+#  Redistribution and use in source and binary forms, with or without
+#  modification, are permitted provided that the following conditions are met:
+#      * Redistributions of source code must retain the above copyright
+#        notice, this list of conditions and the following disclaimer.
+#      * Redistributions in binary form must reproduce the above copyright
+#        notice, this list of conditions and the following disclaimer in the
+#        documentation and/or other materials provided with the distribution.
+#      * Neither the name of the NVIDIA CORPORATION nor the
+#        names of its contributors may be used to endorse or promote products
+#        derived from this software without specific prior written permission.
+#
+#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+#  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+#  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+#  DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+#  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+#  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+#  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+#  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+# *****************************************************************************
+
+import torch
+import torch.nn.functional as F
+from librosa.filters import mel as librosa_mel_fn
+
+from common.audio_processing import (dynamic_range_compression,
+                                     dynamic_range_decompression)
+from common.stft import STFT
+
+
+class LinearNorm(torch.nn.Module):
+    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
+        super(LinearNorm, self).__init__()
+        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+
+        torch.nn.init.xavier_uniform_(
+            self.linear_layer.weight,
+            gain=torch.nn.init.calculate_gain(w_init_gain))
+
+    def forward(self, x):
+        return self.linear_layer(x)
+
+
+class ConvNorm(torch.nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
+                 padding=None, dilation=1, bias=True, w_init_gain='linear',
+                 batch_norm=False):
+        super(ConvNorm, self).__init__()
+        if padding is None:
+            assert(kernel_size % 2 == 1)
+            padding = int(dilation * (kernel_size - 1) / 2)
+
+        self.conv = torch.nn.Conv1d(in_channels, out_channels,
+                                    kernel_size=kernel_size, stride=stride,
+                                    padding=padding, dilation=dilation,
+                                    bias=bias)
+        self.norm = torch.nn.BatchNorm1d(out_channels) if batch_norm else None
+
+        torch.nn.init.xavier_uniform_(
+            self.conv.weight,
+            gain=torch.nn.init.calculate_gain(w_init_gain))
+
+    def forward(self, signal):
+        if self.norm is None:
+            return self.conv(signal)
+        else:
+            return self.norm(self.conv(signal))
+
+
+class ConvReLUNorm(torch.nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=1, dropout=0.0):
+        super(ConvReLUNorm, self).__init__()
+        self.conv = torch.nn.Conv1d(in_channels, out_channels,
+                                    kernel_size=kernel_size,
+                                    padding=(kernel_size // 2))
+        self.norm = torch.nn.LayerNorm(out_channels)
+        self.dropout = torch.nn.Dropout(dropout)
+
+    def forward(self, signal):
+        out = F.relu(self.conv(signal))
+        out = self.norm(out.transpose(1, 2)).transpose(1, 2).to(signal.dtype)
+        return self.dropout(out)
+
+
+class TacotronSTFT(torch.nn.Module):
+    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
+                 n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
+                 mel_fmax=8000.0):
+        super(TacotronSTFT, self).__init__()
+        self.n_mel_channels = n_mel_channels
+        self.sampling_rate = sampling_rate
+        self.stft_fn = STFT(filter_length, hop_length, win_length)
+        mel_basis = librosa_mel_fn(
+            sr=sampling_rate,
+            n_fft=filter_length,
+            n_mels=n_mel_channels,
+            fmin=mel_fmin,
+            fmax=mel_fmax
+        )
+        mel_basis = torch.from_numpy(mel_basis).float()
+        self.register_buffer('mel_basis', mel_basis)
+
+    def spectral_normalize(self, magnitudes):
+        output = dynamic_range_compression(magnitudes)
+        return output
+
+    def spectral_de_normalize(self, magnitudes):
+        output = dynamic_range_decompression(magnitudes)
+        return output
+
+    def mel_spectrogram(self, y):
+        """Computes mel-spectrograms from a batch of waves
+        PARAMS
+        ------
+        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
+
+        RETURNS
+        -------
+        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
+        """
+        assert(torch.min(y.data) >= -1)
+        assert(torch.max(y.data) <= 1)
+
+        magnitudes, phases = self.stft_fn.transform(y)
+        magnitudes = magnitudes.data
+        mel_output = torch.matmul(self.mel_basis, magnitudes)
+        mel_output = self.spectral_normalize(mel_output)
+        return mel_output
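TacotronSTFT.mel_spectrogram expects a float batch of shape (B, T) in [-1, 1] and returns (B, n_mel_channels, frames); a minimal sketch with random audio standing in for real speech (assumes torch, librosa, and scipy are installed):

    import torch
    from common.layers import TacotronSTFT

    stft = TacotronSTFT(filter_length=1024, hop_length=256, win_length=1024,
                        n_mel_channels=80, sampling_rate=22050)
    wav = torch.empty(2, 22050).uniform_(-1, 1)   # two 1-second fake clips
    mel = stft.mel_spectrogram(wav)
    print(mel.shape)   # torch.Size([2, 80, 87]) with these hop/length settings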
common/repeated_dataloader.py
ADDED
@@ -0,0 +1,59 @@
+# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Data pipeline elements which wrap the data N times
+
+A RepeatedDataLoader resets its iterator less frequently. This saves time
+on multi-GPU platforms and is invisible to the training loop.
+
+NOTE: Repeating puts a block of (len(dataset) * repeats) int64s into RAM.
+Do not use more repeats than necessary (e.g., 10**6 to simulate infinity).
+"""
+
+import itertools
+
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+
+class RepeatedDataLoader(DataLoader):
+    def __init__(self, repeats, *args, **kwargs):
+        self.repeats = repeats
+        super().__init__(*args, **kwargs)
+
+    def __iter__(self):
+        if self._iterator is None or self.repeats_done >= self.repeats:
+            self.repeats_done = 1
+            return super().__iter__()
+        else:
+            self.repeats_done += 1
+            return self._iterator
+
+
+class RepeatedDistributedSampler(DistributedSampler):
+    def __init__(self, repeats, *args, **kwargs):
+        self.repeats = repeats
+        assert self.repeats <= 10000, "Too many repeats overload RAM."
+        super().__init__(*args, **kwargs)
+
+    def __iter__(self):
+        # Draw indices for `self.repeats` epochs forward
+        start_epoch = self.epoch
+        iters = []
+        for r in range(self.repeats):
+            self.set_epoch(start_epoch + r)
+            iters.append(super().__iter__())
+        self.set_epoch(start_epoch)
+
+        return itertools.chain.from_iterable(iters)
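The two classes pair up: the sampler lays out indices for repeats epochs in a single pass, and the loader keeps handing back the same live iterator until repeats epochs have been consumed, so persistent workers are not torn down at every epoch boundary. A hedged single-process sketch (num_replicas and rank are passed explicitly so no process group is needed; a real run would rely on the distributed defaults):

    import torch
    from torch.utils.data import TensorDataset
    from common.repeated_dataloader import (RepeatedDataLoader,
                                            RepeatedDistributedSampler)

    repeats = 4
    ds = TensorDataset(torch.arange(16))
    sampler = RepeatedDistributedSampler(repeats, ds, num_replicas=1, rank=0)
    dl = RepeatedDataLoader(repeats, ds, batch_size=4, sampler=sampler,
                            num_workers=2, persistent_workers=True)

    steps_per_epoch = len(ds) // 4
    for epoch in range(repeats):
        it = iter(dl)                    # same underlying iterator each epoch
        for _ in range(steps_per_epoch):
            batch = next(it)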
common/stft.py
ADDED
@@ -0,0 +1,140 @@
+"""
+BSD 3-Clause License
+
+Copyright (c) 2017, Prem Seetharaman
+All rights reserved.
+
+* Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice,
+this list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice, this
+list of conditions and the following disclaimer in the
+documentation and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+contributors may be used to endorse or promote products derived from this
+software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+from torch.autograd import Variable
+from scipy.signal import get_window
+from librosa.util import pad_center, tiny
+from common.audio_processing import window_sumsquare
+
+
+class STFT(torch.nn.Module):
+    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
+    def __init__(self, filter_length=800, hop_length=200, win_length=800,
+                 window='hann', device="cpu"):
+        super(STFT, self).__init__()
+        self.filter_length = filter_length
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.window = window
+        self.forward_transform = None
+        scale = self.filter_length / self.hop_length
+        fourier_basis = np.fft.fft(np.eye(self.filter_length))
+
+        cutoff = int((self.filter_length / 2 + 1))
+        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
+                                   np.imag(fourier_basis[:cutoff, :])])
+
+        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
+        inverse_basis = torch.FloatTensor(
+            np.linalg.pinv(scale * fourier_basis).T[:, None, :].copy())
+
+        if window is not None:
+            assert(filter_length >= win_length)
+            # get window and zero center pad it to filter_length
+            fft_window = get_window(window, win_length, fftbins=True)
+            fft_window = pad_center(fft_window, size=filter_length)
+            fft_window = torch.from_numpy(fft_window).float()
+
+            # window the bases
+            forward_basis *= fft_window
+            inverse_basis *= fft_window
+
+        self.register_buffer('forward_basis', forward_basis.float().to(device))
+        self.register_buffer('inverse_basis', inverse_basis.float().to(device))
+
+    def transform(self, input_data):
+        num_batches = input_data.size(0)
+        num_samples = input_data.size(1)
+
+        self.num_samples = num_samples
+
+        # similar to librosa, reflect-pad the input
+        input_data = input_data.view(num_batches, 1, num_samples)
+        input_data = F.pad(
+            input_data.unsqueeze(1),
+            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
+            mode='reflect')
+        input_data = input_data.squeeze(1)
+        # print(self.forward_basis.device)
+        forward_transform = F.conv1d(
+            input_data,
+            Variable(self.forward_basis, requires_grad=False),
+            stride=self.hop_length,
+            padding=0)
+
+        cutoff = int((self.filter_length / 2) + 1)
+        real_part = forward_transform[:, :cutoff, :]
+        imag_part = forward_transform[:, cutoff:, :]
+
+        magnitude = torch.sqrt(real_part**2 + imag_part**2)
+        phase = torch.autograd.Variable(
+            torch.atan2(imag_part.data, real_part.data))
+
+        return magnitude, phase
+
+    def inverse(self, magnitude, phase):
+        recombine_magnitude_phase = torch.cat(
+            [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
+
+        with torch.no_grad():
+            inverse_transform = F.conv_transpose1d(
+                recombine_magnitude_phase, self.inverse_basis,
+                stride=self.hop_length, padding=0)
+
+        if self.window is not None:
+            window_sum = window_sumsquare(
+                self.window, magnitude.size(-1), hop_length=self.hop_length,
+                win_length=self.win_length, n_fft=self.filter_length,
+                dtype=np.float32)
+            # remove modulation effects
+            approx_nonzero_indices = torch.from_numpy(
+                np.where(window_sum > tiny(window_sum))[0])
+            window_sum = torch.autograd.Variable(
+                torch.from_numpy(window_sum), requires_grad=False)
+            window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
+            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
+
+            # scale by hop ratio
+            inverse_transform *= float(self.filter_length) / self.hop_length
+
+        inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
+        inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2)]
+
+        return inverse_transform
+
+    def forward(self, input_data):
+        self.magnitude, self.phase = self.transform(input_data)
+        reconstruction = self.inverse(self.magnitude, self.phase)
+        return reconstruction
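forward() is transform followed by inverse, so a round trip recovers the waveform up to the reflect padding that gets trimmed off; a quick shape check (assumes librosa and scipy are installed):

    import torch
    from common.stft import STFT

    stft = STFT(filter_length=800, hop_length=200, win_length=800, window='hann')
    wav = torch.empty(1, 4000).uniform_(-1, 1)
    mag, phase = stft.transform(wav)
    print(mag.shape)      # (1, 401, 21): filter_length/2 + 1 frequency bins
    recon = stft(wav)     # transform + inverse
    print(recon.shape)    # (1, 1, 4000)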
common/tb_dllogger.py
ADDED
@@ -0,0 +1,172 @@
+import atexit
+import glob
+import re
+from itertools import product
+from pathlib import Path
+
+import dllogger
+import torch
+import numpy as np
+from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
+from torch.utils.tensorboard import SummaryWriter
+
+
+tb_loggers = {}
+
+
+class TBLogger:
+    """
+    xyz_dummies: stretch the screen with empty plots so the legend would
+                 always fit for other plots
+    """
+    def __init__(self, enabled, log_dir, name, interval=1, dummies=True):
+        self.enabled = enabled
+        self.interval = interval
+        self.cache = {}
+        if self.enabled:
+            self.summary_writer = SummaryWriter(
+                log_dir=Path(log_dir, name), flush_secs=120, max_queue=200)
+            atexit.register(self.summary_writer.close)
+            if dummies:
+                for key in ('_', '✕'):
+                    self.summary_writer.add_scalar(key, 0.0, 1)
+
+    def log(self, step, data):
+        for k, v in data.items():
+            self.log_value(step, k, v.item() if type(v) is torch.Tensor else v)
+
+    def log_value(self, step, key, val, stat='mean'):
+        if self.enabled:
+            if key not in self.cache:
+                self.cache[key] = []
+            self.cache[key].append(val)
+            if len(self.cache[key]) == self.interval:
+                agg_val = getattr(np, stat)(self.cache[key])
+                self.summary_writer.add_scalar(key, agg_val, step)
+                del self.cache[key]
+
+    def log_grads(self, step, model):
+        if self.enabled:
+            norms = [p.grad.norm().item() for p in model.parameters()
+                     if p.grad is not None]
+            for stat in ('max', 'min', 'mean'):
+                self.log_value(step, f'grad_{stat}', getattr(np, stat)(norms),
+                               stat=stat)
+
+
+def unique_log_fpath(fpath):
+    if not Path(fpath).is_file():
+        return fpath
+
+    # Avoid overwriting old logs
+    saved = [re.search(r'\.(\d+)$', f) for f in glob.glob(f'{fpath}.*')]
+    saved = [0] + [int(m.group(1)) for m in saved if m is not None]
+    return f'{fpath}.{max(saved) + 1}'
+
+
+def stdout_step_format(step):
+    if isinstance(step, str):
+        return step
+    fields = []
+    if len(step) > 0:
+        fields.append("epoch {:>4}".format(step[0]))
+    if len(step) > 1:
+        fields.append("iter {:>3}".format(step[1]))
+    if len(step) > 2:
+        fields[-1] += "/{}".format(step[2])
+    return " | ".join(fields)
+
+
+def stdout_metric_format(metric, metadata, value):
+    name = metadata.get("name", metric + " : ")
+    unit = metadata.get("unit", None)
+    format = f'{{{metadata.get("format", "")}}}'
+    fields = [name, format.format(value) if value is not None else value, unit]
+    fields = [f for f in fields if f is not None]
+    return "| " + " ".join(fields)
+
+
+def init(log_fpath, log_dir, enabled=True, tb_subsets=[], **tb_kw):
+    if enabled:
+        backends = [JSONStreamBackend(Verbosity.DEFAULT,
+                                      unique_log_fpath(log_fpath)),
+                    StdOutBackend(Verbosity.VERBOSE,
+                                  step_format=stdout_step_format,
+                                  metric_format=stdout_metric_format)]
+    else:
+        backends = []
+
+    dllogger.init(backends=backends)
+    dllogger.metadata("train_lrate", {"name": "lrate", "unit": None, "format": ":>3.2e"})
+
+    for id_, pref in [('train', ''), ('train_avg', 'avg train '),
+                      ('val', '  avg val '), ('val_ema', '  EMA val ')]:
+        dllogger.metadata(f"{id_}_loss",
+                          {"name": f"{pref}loss", "unit": None, "format": ":>5.2f"})
+        dllogger.metadata(f"{id_}_mel_loss",
+                          {"name": f"{pref}mel loss", "unit": None, "format": ":>5.2f"})
+
+        dllogger.metadata(f"{id_}_kl_loss",
+                          {"name": f"{pref}kl loss", "unit": None, "format": ":>5.5f"})
+        dllogger.metadata(f"{id_}_kl_weight",
+                          {"name": f"{pref}kl weight", "unit": None, "format": ":>5.5f"})
+
+        dllogger.metadata(f"{id_}_frames/s",
+                          {"name": None, "unit": "frames/s", "format": ":>10.2f"})
+        dllogger.metadata(f"{id_}_took",
+                          {"name": "took", "unit": "s", "format": ":>3.2f"})
+
+    global tb_loggers
+    tb_loggers = {s: TBLogger(enabled, log_dir, name=s, **tb_kw)
+                  for s in tb_subsets}
+
+
+def init_inference_metadata(batch_size=None):
+    modalities = [('latency', 's', ':>10.5f'), ('RTF', 'x', ':>10.2f'),
+                  ('frames/s', 'frames/s', ':>10.2f'), ('samples/s', 'samples/s', ':>10.2f'),
+                  ('letters/s', 'letters/s', ':>10.2f'), ('tokens/s', 'tokens/s', ':>10.2f')]
+
+    if batch_size is not None:
+        modalities.append((f'RTF@{batch_size}', 'x', ':>10.2f'))
+
+    percs = ['', 'avg', '90%', '95%', '99%']
+    models = ['', 'fastpitch', 'waveglow', 'hifigan']
+
+    for perc, model, (mod, unit, fmt) in product(percs, models, modalities):
+        name = f'{perc} {model} {mod}'.strip().replace('  ', ' ')
+        dllogger.metadata(name.replace(' ', '_'),
+                          {'name': f'{name: <26}', 'unit': unit, 'format': fmt})
+
+
+def log(step, tb_total_steps=None, data={}, subset='train'):
+    if tb_total_steps is not None:
+        tb_loggers[subset].log(tb_total_steps, data)
+
+    if subset != '':
+        data = {f'{subset}_{key}': v for key, v in data.items()}
+    dllogger.log(step, data=data)
+
+
+def log_grads_tb(tb_total_steps, grads, tb_subset='train'):
+    tb_loggers[tb_subset].log_grads(tb_total_steps, grads)
+
+
+def parameters(data, verbosity=0, tb_subset=None):
+    for k, v in data.items():
+        dllogger.log(step="PARAMETER", data={k: v}, verbosity=verbosity)
+
+    if tb_subset is not None and tb_loggers[tb_subset].enabled:
+        tb_data = {k: v for k, v in data.items()
+                   if type(v) in (str, bool, int, float)}
+        tb_loggers[tb_subset].summary_writer.add_hparams(tb_data, {})
+
+
+def flush():
+    dllogger.flush()
+    for tbl in tb_loggers.values():
+        if tbl.enabled:
+            tbl.summary_writer.flush()
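A training script typically calls init once and then logs flat dicts; the subset argument prefixes keys so they match the metadata registered in init. A sketch assuming dllogger and tensorboard are installed:

    import common.tb_dllogger as logger

    logger.init(log_fpath='logs/nvlog.json', log_dir='logs',
                enabled=True, tb_subsets=['train', 'val'])
    logger.parameters({'batch_size': 16, 'lr': 1e-4}, tb_subset='train')

    for epoch in range(2):
        for it in range(10):
            step = (epoch, it, 10)   # rendered as "epoch    0 | iter   3/10"
            logger.log(step, tb_total_steps=epoch * 10 + it,
                       data={'loss': 1.0 / (it + 1)}, subset='train')
    logger.flush()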
common/text/LICENSE
ADDED
@@ -0,0 +1,19 @@
+Copyright (c) 2017 Keith Ito
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
common/text/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .cmudict import CMUDict
+
+cmudict = CMUDict()
common/text/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (203 Bytes).

common/text/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (230 Bytes).

common/text/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (230 Bytes).

common/text/__pycache__/abbreviations.cpython-37.pyc
ADDED
Binary file (1.84 kB).

common/text/__pycache__/abbreviations.cpython-38.pyc
ADDED
Binary file (1.87 kB).

common/text/__pycache__/abbreviations.cpython-39.pyc
ADDED
Binary file (1.87 kB).

common/text/__pycache__/acronyms.cpython-37.pyc
ADDED
Binary file (2.55 kB).

common/text/__pycache__/acronyms.cpython-38.pyc
ADDED
Binary file (2.68 kB).

common/text/__pycache__/acronyms.cpython-39.pyc
ADDED
Binary file (2.54 kB).

common/text/__pycache__/cleaners.cpython-37.pyc
ADDED
Binary file (2.7 kB).

common/text/__pycache__/cleaners.cpython-38.pyc
ADDED
Binary file (2.77 kB).

common/text/__pycache__/cleaners.cpython-39.pyc
ADDED
Binary file (2.75 kB).

common/text/__pycache__/cmudict.cpython-37.pyc
ADDED
Binary file (3.79 kB).

common/text/__pycache__/cmudict.cpython-38.pyc
ADDED
Binary file (3.99 kB).

common/text/__pycache__/cmudict.cpython-39.pyc
ADDED
Binary file (3.69 kB).

common/text/__pycache__/datestime.cpython-37.pyc
ADDED
Binary file (720 Bytes).

common/text/__pycache__/datestime.cpython-38.pyc
ADDED
Binary file (757 Bytes).

common/text/__pycache__/datestime.cpython-39.pyc
ADDED
Binary file (757 Bytes).

common/text/__pycache__/letters_and_numbers.cpython-37.pyc
ADDED
Binary file (2.88 kB).

common/text/__pycache__/letters_and_numbers.cpython-38.pyc
ADDED
Binary file (2.94 kB).

common/text/__pycache__/letters_and_numbers.cpython-39.pyc
ADDED
Binary file (2.92 kB).

common/text/__pycache__/numerical.cpython-37.pyc
ADDED
Binary file (4.65 kB).

common/text/__pycache__/numerical.cpython-38.pyc
ADDED
Binary file (4.7 kB).

common/text/__pycache__/numerical.cpython-39.pyc
ADDED
Binary file (4.71 kB).

common/text/__pycache__/symbols.cpython-37.pyc
ADDED
Binary file (1.5 kB).

common/text/__pycache__/symbols.cpython-38.pyc
ADDED
Binary file (1.96 kB).

common/text/__pycache__/symbols.cpython-39.pyc
ADDED
Binary file (2.16 kB).

common/text/__pycache__/text_processing.cpython-37.pyc
ADDED
Binary file (4.96 kB).

common/text/__pycache__/text_processing.cpython-38.pyc
ADDED
Binary file (5.04 kB).

common/text/__pycache__/text_processing.cpython-39.pyc
ADDED
Binary file (4.36 kB).
common/text/abbreviations.py
ADDED
@@ -0,0 +1,67 @@
+import re
+
+_no_period_re = re.compile(r'(No[.])(?=[ ]?[0-9])')
+_percent_re = re.compile(r'([ ]?[%])')
+_half_re = re.compile('([0-9]½)|(½)')
+_url_re = re.compile(r'([a-zA-Z])\.(com|gov|org)')
+
+
+# List of (regular expression, replacement) pairs for abbreviations:
+_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
+    ('mrs', 'misess'),
+    ('ms', 'miss'),
+    ('mr', 'mister'),
+    ('dr', 'doctor'),
+    ('st', 'saint'),
+    ('co', 'company'),
+    ('jr', 'junior'),
+    ('maj', 'major'),
+    ('gen', 'general'),
+    ('drs', 'doctors'),
+    ('rev', 'reverend'),
+    ('lt', 'lieutenant'),
+    ('hon', 'honorable'),
+    ('sgt', 'sergeant'),
+    ('capt', 'captain'),
+    ('esq', 'esquire'),
+    ('ltd', 'limited'),
+    ('col', 'colonel'),
+    ('ft', 'fort'),
+    ('sen', 'senator'),
+    ('etc', 'et cetera'),
+]]
+
+
+def _expand_no_period(m):
+    word = m.group(0)
+    if word[0] == 'N':
+        return 'Number'
+    return 'number'
+
+
+def _expand_percent(m):
+    return ' percent'
+
+
+def _expand_half(m):
+    word = m.group(1)
+    if word is None:
+        return 'half'
+    return word[0] + ' and a half'
+
+
+def _expand_urls(m):
+    return f'{m.group(1)} dot {m.group(2)}'
+
+
+def normalize_abbreviations(text):
+    text = re.sub(_no_period_re, _expand_no_period, text)
+    text = re.sub(_percent_re, _expand_percent, text)
+    text = re.sub(_half_re, _expand_half, text)
+    text = re.sub('&', ' and ', text)
+    text = re.sub('@', ' at ', text)
+    text = re.sub(_url_re, _expand_urls, text)
+
+    for regex, replacement in _abbreviations:
+        text = re.sub(regex, replacement, text)
+    return text
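The replacement rules fire in one pass over the text: titles expand, '%' becomes ' percent', '&' and '@' are spelled out, and bare .com/.gov/.org hosts are read as 'dot'. For example:

    from common.text.abbreviations import normalize_abbreviations

    print(normalize_abbreviations("Dr. Smith of example.com owns 50% of the co."))
    # -> "doctor Smith of example dot com owns 50 percent of the company"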
common/text/acronyms.py
ADDED
@@ -0,0 +1,109 @@
+import re
+from . import cmudict
+
+_letter_to_arpabet = {
+    'A': 'EY1',
+    'B': 'B IY1',
+    'C': 'S IY1',
+    'D': 'D IY1',
+    'E': 'IY1',
+    'F': 'EH1 F',
+    'G': 'JH IY1',
+    'H': 'EY1 CH',
+    'I': 'AY1',
+    'J': 'JH EY1',
+    'K': 'K EY1',
+    'L': 'EH1 L',
+    'M': 'EH1 M',
+    'N': 'EH1 N',
+    'O': 'OW1',
+    'P': 'P IY1',
+    'Q': 'K Y UW1',
+    'R': 'AA1 R',
+    'S': 'EH1 S',
+    'T': 'T IY1',
+    'U': 'Y UW1',
+    'V': 'V IY1',
+    'X': 'EH1 K S',
+    'Y': 'W AY1',
+    'W': 'D AH1 B AH0 L Y UW0',
+    'Z': 'Z IY1',
+    's': 'Z'
+}
+
+# Acronyms that should not be expanded
+hardcoded_acronyms = [
+    'BMW', 'MVD', 'WDSU', 'GOP', 'UK', 'AI', 'GPS', 'BP', 'FBI', 'HD',
+    'CES', 'LRA', 'PC', 'NBA', 'BBL', 'OS', 'IRS', 'SAC', 'UV', 'CEO', 'TV',
+    'CNN', 'MSS', 'GSA', 'USSR', 'DNA', 'PRS', 'TSA', 'US', 'GPU', 'USA',
+    'FPCC', 'CIA']
+
+# Words and acronyms that should be read as regular words, e.g., NATO, HAPPY, etc.
+uppercase_whitelist = []
+
+acronyms_exceptions = {
+    'NVIDIA': 'N.VIDIA',
+}
+
+non_uppercase_exceptions = {
+    'email': 'e-mail',
+}
+
+# must ignore roman numerals
+_acronym_re = re.compile(r'([a-z]*[A-Z][A-Z]+)s?\.?')
+_non_uppercase_re = re.compile(r'\b({})\b'.format('|'.join(non_uppercase_exceptions.keys())), re.IGNORECASE)
+
+
+def _expand_acronyms_to_arpa(m, add_spaces=True):
+    acronym = m.group(0)
+
+    # remove dots if they exist
+    acronym = re.sub(r'\.', '', acronym)
+
+    acronym = "".join(acronym.split())
+    arpabet = cmudict.lookup(acronym)
+
+    if arpabet is None:
+        acronym = list(acronym)
+        arpabet = ["{" + _letter_to_arpabet[letter] + "}" for letter in acronym]
+        # temporary fix
+        if arpabet[-1] == '{Z}' and len(arpabet) > 1:
+            arpabet[-2] = arpabet[-2][:-1] + ' ' + arpabet[-1][1:]
+            del arpabet[-1]
+
+        arpabet = ' '.join(arpabet)
+    elif len(arpabet) == 1:
+        arpabet = "{" + arpabet[0] + "}"
+    else:
+        arpabet = acronym
+
+    return arpabet
+
+
+def normalize_acronyms(text):
+    text = re.sub(_acronym_re, _expand_acronyms_to_arpa, text)
+    return text
+
+
+def expand_acronyms(m):
+    text = m.group(1)
+    if text in acronyms_exceptions:
+        text = acronyms_exceptions[text]
+    elif text in uppercase_whitelist:
+        text = text
+    else:
+        text = '.'.join(text) + '.'
+
+    if 's' in m.group(0):
+        text = text + '\'s'
+
+    if text[-1] != '.' and m.group(0)[-1] == '.':
+        return text + '.'
+    else:
+        return text
+
+
+def spell_acronyms(text):
+    text = re.sub(_non_uppercase_re, lambda m: non_uppercase_exceptions[m.group(0).lower()], text)
+    text = re.sub(_acronym_re, expand_acronyms, text)
+    return text
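spell_acronyms rewrites runs of capitals letter by letter with periods (honoring acronyms_exceptions), while normalize_acronyms goes further and emits ARPAbet via CMUDict. A quick check of the former:

    from common.text.acronyms import spell_acronyms

    print(spell_acronyms("The FBI met NVIDIA engineers."))
    # -> "The F.B.I. met N.VIDIA engineers."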
common/text/cleaners.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" adapted from https://github.com/keithito/tacotron """

'''
Cleaners are transformations that run over the input text at both training and eval time.

Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
  1. "english_cleaners" for English text
  2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
     the Unidecode library (https://pypi.python.org/pypi/Unidecode)
  3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
     the symbols in symbols.py to match your data).
'''

import re
from .abbreviations import normalize_abbreviations
from .acronyms import normalize_acronyms, spell_acronyms
from .datestime import normalize_datestime
from .letters_and_numbers import normalize_letters_and_numbers
from .numerical import normalize_numbers
from .unidecoder import unidecoder


# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')


def expand_abbreviations(text):
    return normalize_abbreviations(text)


def expand_numbers(text):
    return normalize_numbers(text)


def expand_acronyms(text):
    return normalize_acronyms(text)


def expand_datestime(text):
    return normalize_datestime(text)


def expand_letters_and_numbers(text):
    return normalize_letters_and_numbers(text)


def lowercase(text):
    return text.lower()


def collapse_whitespace(text):
    return re.sub(_whitespace_re, ' ', text)


def separate_acronyms(text):
    text = re.sub(r"([0-9]+)([a-zA-Z]+)", r"\1 \2", text)
    text = re.sub(r"([a-zA-Z]+)([0-9]+)", r"\1 \2", text)
    return text


def convert_to_ascii(text):
    return unidecoder(text)


def basic_cleaners(text):
    '''Basic pipeline that collapses whitespace without transliteration.'''
    # text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def transliteration_cleaners(text):
    '''Pipeline for non-English text that transliterates to ASCII.'''
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def english_cleaners(text):
    '''Pipeline for English text, with number and abbreviation expansion.'''
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = collapse_whitespace(text)
    return text


def english_cleaners_v2(text):
    text = convert_to_ascii(text)
    text = expand_datestime(text)
    text = expand_letters_and_numbers(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = spell_acronyms(text)
    text = lowercase(text)
    text = collapse_whitespace(text)
    # compatibility with basic_english symbol set
    text = re.sub(r'/+', ' ', text)
    return text
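For reference, a small usage sketch of the pipelines above. basic_cleaners is fully deterministic; english_cleaners_v2 additionally depends on the abbreviation/acronym tables and the bundled unidecoder, so its output below is indicative only:

from common.text.cleaners import basic_cleaners, english_cleaners_v2

print(basic_cleaners("  Hello   world "))
# -> ' Hello world '
print(english_cleaners_v2("Dr. Jones arrives at 10:30pm"))
# e.g. 'doctor jones arrives at ten thirty p.m.' (exact output depends on the tables)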
common/text/cmudict.py
ADDED
@@ -0,0 +1,98 @@
""" from https://github.com/keithito/tacotron """

import re
import sys
import urllib.request
from pathlib import Path


valid_symbols = [
    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
    'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
    'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
    'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
    'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
    'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
    'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
]

_valid_symbol_set = set(valid_symbols)


class CMUDict:
    '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
    def __init__(self, file_or_path=None, heteronyms_path=None, keep_ambiguous=True):
        self._entries = {}
        self.heteronyms = []
        if file_or_path is not None:
            self.initialize(file_or_path, heteronyms_path, keep_ambiguous)

    def initialize(self, file_or_path, heteronyms_path, keep_ambiguous=True):
        if isinstance(file_or_path, str):
            if not Path(file_or_path).exists():
                print("CMUdict missing. Downloading to cmudict/.")
                self.download()

            with open(file_or_path, encoding='latin-1') as f:
                entries = _parse_cmudict(f)

        else:
            entries = _parse_cmudict(file_or_path)

        if not keep_ambiguous:
            entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
        self._entries = entries

        if heteronyms_path is not None:
            with open(heteronyms_path, encoding='utf-8') as f:
                self.heteronyms = [l.rstrip() for l in f]

    def __len__(self):
        if len(self._entries) == 0:
            raise ValueError("CMUDict not initialized")
        return len(self._entries)

    def lookup(self, word):
        '''Returns list of ARPAbet pronunciations of the given word.'''
        if len(self._entries) == 0:
            raise ValueError("CMUDict not initialized")
        return self._entries.get(word.upper())

    def download(self):
        url = 'https://github.com/Alexir/CMUdict/raw/master/cmudict-0.7b'
        try:
            Path('cmudict').mkdir(parents=False, exist_ok=True)
            urllib.request.urlretrieve(url, filename='cmudict/cmudict-0.7b')
        except Exception:
            print("Automatic download of CMUdict failed. Try manually with:")
            print()
            print("    bash scripts/download_cmudict.sh")
            print()
            print("and re-run the script.")
            sys.exit(0)


_alt_re = re.compile(r'\([0-9]+\)')


def _parse_cmudict(file):
    cmudict = {}
    for line in file:
        if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
            # word and pronunciation are separated by two spaces in cmudict-0.7b
            parts = line.split('  ')
            word = re.sub(_alt_re, '', parts[0])
            pronunciation = _get_pronunciation(parts[1])
            if pronunciation:
                if word in cmudict:
                    cmudict[word].append(pronunciation)
                else:
                    cmudict[word] = [pronunciation]
    return cmudict


def _get_pronunciation(s):
    parts = s.strip().split(' ')
    for part in parts:
        if part not in _valid_symbol_set:
            return None
    return ' '.join(parts)
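A minimal usage sketch; the path below matches the default download location used by download() above, and assumes the repo root is on sys.path:

from common.text.cmudict import CMUDict

cmu = CMUDict('cmudict/cmudict-0.7b')   # downloads the file first if it is missing
print(cmu.lookup('hello'))              # e.g. ['HH AH0 L OW1'], a list of ARPAbet strings
print(len(cmu))                         # number of parsed entries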
common/text/datestime.py
ADDED
@@ -0,0 +1,22 @@
import re

_ampm_re = re.compile(
    r'([0-9]|0[0-9]|1[0-9]|2[0-3]):?([0-5][0-9])?\s*([AaPp][Mm]\b)')


def _expand_ampm(m):
    matches = list(m.groups(0))
    txt = matches[0]
    txt = txt if int(matches[1]) == 0 else txt + ' ' + matches[1]

    if matches[2][0].lower() == 'a':
        txt += ' a.m.'
    elif matches[2][0].lower() == 'p':
        txt += ' p.m.'

    return txt


def normalize_datestime(text):
    text = re.sub(_ampm_re, _expand_ampm, text)
    # text = re.sub(r"([0-9]|0[0-9]|1[0-9]|2[0-3]):([0-5][0-9])?", r"\1 \2", text)
    return text
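Two deterministic examples of the a.m./p.m. rewrite above (import path assumes the repo root is on sys.path):

from common.text.datestime import normalize_datestime

print(normalize_datestime("Doors open at 10:30pm"))  # -> 'Doors open at 10 30 p.m.'
print(normalize_datestime("back at 7 AM"))           # -> 'back at 7 a.m.'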
common/text/letters_and_numbers.py
ADDED
@@ -0,0 +1,90 @@
import re

_letters_and_numbers_re = re.compile(
    r"((?:[a-zA-Z]+[0-9]|[0-9]+[a-zA-Z])[a-zA-Z0-9']*)", re.IGNORECASE)

_hardware_re = re.compile(
    r'([0-9]+(?:[.,][0-9]+)?)(?:\s?)(tb|gb|mb|kb|ghz|mhz|khz|hz|mm)', re.IGNORECASE)
_hardware_key = {'tb': 'terabyte',
                 'gb': 'gigabyte',
                 'mb': 'megabyte',
                 'kb': 'kilobyte',
                 'ghz': 'gigahertz',
                 'mhz': 'megahertz',
                 'khz': 'kilohertz',
                 'hz': 'hertz',
                 'mm': 'millimeter',
                 'cm': 'centimeter',
                 'km': 'kilometer'}

_dimension_re = re.compile(
    r'\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b|\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b')
_dimension_key = {'m': 'meter',
                  'in': 'inch',
                  'inch': 'inch'}


def _expand_letters_and_numbers(m):
    text = re.split(r'(\d+)', m.group(0))

    # drop the empty string that re.split leaves at either end
    if text[-1] == '':
        text = text[:-1]
    elif text[0] == '':
        text = text[1:]

    # keep suffixes like 1920s, AK47's, 20th, 1st, 2nd, 3rd attached to the digits
    if text[-1] in ("'s", "s", "th", "nd", "st", "rd") and text[-2].isdigit():
        text[-2] = text[-2] + text[-1]
        text = text[:-1]

    # combine digits two by two
    new_text = []
    for i in range(len(text)):
        string = text[i]
        if string.isdigit() and len(string) < 5:
            # heuristics
            if len(string) > 2 and string[-2] == '0':
                if string[-1] == '0':
                    string = [string]
                else:
                    string = [string[:-2], string[-2], string[-1]]
            elif len(string) % 2 == 0:
                string = [string[i:i+2] for i in range(0, len(string), 2)]
            elif len(string) > 2:
                string = [string[0]] + [string[i:i+2] for i in range(1, len(string), 2)]
            new_text.extend(string)
        else:
            new_text.append(string)

    text = new_text
    text = " ".join(text)
    return text


def _expand_hardware(m):
    quantity, measure = m.groups(0)
    measure = _hardware_key[measure.lower()]
    if measure[-1] != 'z' and float(quantity.replace(',', '')) > 1:
        return "{} {}s".format(quantity, measure)
    return "{} {}".format(quantity, measure)


def _expand_dimension(m):
    text = "".join([x for x in m.groups(0) if x != 0])
    text = text.replace(' x ', ' by ')
    text = text.replace('x', ' by ')
    if text.endswith(tuple(_dimension_key.keys())):
        if text[-2].isdigit():
            text = "{} {}".format(text[:-1], _dimension_key[text[-1:]])
        elif text[-3].isdigit():
            text = "{} {}".format(text[:-2], _dimension_key[text[-2:]])
    return text


def normalize_letters_and_numbers(text):
    text = re.sub(_hardware_re, _expand_hardware, text)
    text = re.sub(_dimension_re, _expand_dimension, text)
    text = re.sub(_letters_and_numbers_re, _expand_letters_and_numbers, text)
    return text
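A short sketch of the combined behaviour: hardware units are spelled out first, then mixed letter-digit tokens are split and the digits grouped in pairs (import path assumes the repo root is on sys.path):

from common.text.letters_and_numbers import normalize_letters_and_numbers

print(normalize_letters_and_numbers("an RTX3090 with 24GB"))
# -> 'an RTX 30 90 with 24 gigabytes'
print(normalize_letters_and_numbers("the 1920s"))
# -> 'the 1920s' (the plural year stays intact, for numerical.py to expand)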
common/text/numerical.py
ADDED
@@ -0,0 +1,153 @@
""" adapted from https://github.com/keithito/tacotron """

import inflect
import re

_magnitudes = ['trillion', 'billion', 'million', 'thousand', 'hundred', 'm', 'b', 't']
_magnitudes_key = {'m': 'million', 'b': 'billion', 't': 'trillion'}
_measurements = '(f|c|k|d|m)'
_measurements_key = {'f': 'fahrenheit',
                     'c': 'celsius',
                     'k': 'thousand',
                     'm': 'meters'}
_currency_key = {'$': 'dollar', '£': 'pound', '€': 'euro', '₩': 'won'}
_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_currency_re = re.compile(r'([\$€£₩])([0-9\.\,]*[0-9]+)(?:[ ]?({})(?=[^a-zA-Z]|$))?'.format("|".join(_magnitudes)), re.IGNORECASE)
_measurement_re = re.compile(r'([0-9\.\,]*[0-9]+(\s)?{}\b)'.format(_measurements), re.IGNORECASE)
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
# _range_re = re.compile(r'(?<=[0-9])+(-)(?=[0-9])+.*?')
_roman_re = re.compile(r'\b(?=[MDCLXVI]+\b)M{0,4}(CM|CD|D?C{0,3})(XC|XL|L?X{0,3})(IX|IV|V?I{2,3})\b')  # I{2,3} so a bare 'I' is not treated as a numeral
_multiply_re = re.compile(r'(\b[0-9]+)(x)([0-9]+)')
_number_re = re.compile(r"[0-9]+'s|[0-9]+s|[0-9]+")


def _remove_commas(m):
    return m.group(1).replace(',', '')


def _expand_decimal_point(m):
    return m.group(1).replace('.', ' point ')


def _expand_currency(m):
    currency = _currency_key[m.group(1)]
    quantity = m.group(2)
    magnitude = m.group(3)

    # remove commas from quantity to be able to convert to numerical
    quantity = quantity.replace(',', '')

    # check for million, billion, etc...
    if magnitude is not None and magnitude.lower() in _magnitudes:
        if len(magnitude) == 1:
            magnitude = _magnitudes_key[magnitude.lower()]
        return "{} {} {}".format(_expand_hundreds(quantity), magnitude, currency + 's')

    parts = quantity.split('.')
    if len(parts) > 2:
        return quantity + " " + currency + "s"  # unexpected format

    dollars = int(parts[0]) if parts[0] else 0
    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
    if dollars and cents:
        dollar_unit = currency if dollars == 1 else currency + 's'
        cent_unit = 'cent' if cents == 1 else 'cents'
        return "{} {}, {} {}".format(
            _expand_hundreds(dollars), dollar_unit,
            _inflect.number_to_words(cents), cent_unit)
    elif dollars:
        dollar_unit = currency if dollars == 1 else currency + 's'
        return "{} {}".format(_expand_hundreds(dollars), dollar_unit)
    elif cents:
        cent_unit = 'cent' if cents == 1 else 'cents'
        return "{} {}".format(_inflect.number_to_words(cents), cent_unit)
    else:
        return 'zero' + ' ' + currency + 's'


def _expand_hundreds(text):
    number = float(text)
    if 1000 < number < 10000 and (number % 100 == 0) and (number % 1000 != 0):
        return _inflect.number_to_words(int(number / 100)) + " hundred"
    else:
        return _inflect.number_to_words(text)


def _expand_ordinal(m):
    return _inflect.number_to_words(m.group(0))


def _expand_measurement(m):
    _, number, measurement = re.split(r'(\d+(?:\.\d+)?)', m.group(0))
    number = _inflect.number_to_words(number)
    measurement = "".join(measurement.split())
    measurement = _measurements_key[measurement.lower()]
    return "{} {}".format(number, measurement)


def _expand_range(m):
    return ' to '


def _expand_multiply(m):
    left = m.group(1)
    right = m.group(3)
    return "{} by {}".format(left, right)


def _expand_roman(m):
    # from https://stackoverflow.com/questions/19308177/converting-roman-numerals-to-integers-in-python
    roman_numerals = {'I': 1, 'V': 5, 'X': 10, 'L': 50, 'C': 100, 'D': 500, 'M': 1000}
    result = 0
    num = m.group(0)
    for i, c in enumerate(num):
        if (i + 1) == len(num) or roman_numerals[c] >= roman_numerals[num[i + 1]]:
            result += roman_numerals[c]
        else:
            result -= roman_numerals[c]
    return str(result)


def _expand_number(m):
    _, number, suffix = re.split(r"(\d+(?:'?\d+)?)", m.group(0))
    number = int(number)
    if 1000 < number < 10000 and (number % 100 == 0) and (number % 1000 != 0):
        text = _inflect.number_to_words(number // 100) + " hundred"
    elif number > 1000 and number < 3000:
        if number == 2000:
            text = 'two thousand'
        elif number > 2000 and number < 2010:
            text = 'two thousand ' + _inflect.number_to_words(number % 100)
        elif number % 100 == 0:
            text = _inflect.number_to_words(number // 100) + ' hundred'
        else:
            number = _inflect.number_to_words(number, andword='', zero='oh', group=2).replace(', ', ' ')
            number = re.sub(r'-', ' ', number)
            text = number
    else:
        number = _inflect.number_to_words(number, andword='and')
        number = re.sub(r'-', ' ', number)
        number = re.sub(r',', '', number)
        text = number

    if suffix in ("'s", "s"):
        if text[-1] == 'y':
            text = text[:-1] + 'ies'
        else:
            text = text + suffix

    return text


def normalize_numbers(text):
    text = re.sub(_comma_number_re, _remove_commas, text)
    text = re.sub(_currency_re, _expand_currency, text)
    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
    text = re.sub(_ordinal_re, _expand_ordinal, text)
    # text = re.sub(_range_re, _expand_range, text)
    # text = re.sub(_measurement_re, _expand_measurement, text)
    text = re.sub(_roman_re, _expand_roman, text)
    text = re.sub(_multiply_re, _expand_multiply, text)
    text = re.sub(_number_re, _expand_number, text)
    return text
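A few illustrative calls showing how the substitutions above compose (currency first, then decimals, ordinals, Roman numerals and plain numbers); the exact wording comes from the inflect library, so treat the outputs as indicative:

from common.text.numerical import normalize_numbers

print(normalize_numbers("$3.50"))        # -> 'three dollars, fifty cents'
print(normalize_numbers("in 2019"))      # -> 'in twenty nineteen'
print(normalize_numbers("Chapter XIV"))  # -> 'Chapter fourteen'
print(normalize_numbers("the 1920s"))    # -> 'the nineteen twenties'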
common/text/symbols.py
ADDED
@@ -0,0 +1,81 @@
""" from https://github.com/keithito/tacotron """

'''
Defines the set of symbols used in text input to the model.

The default is a set of ASCII characters that works well for English or text
that has been run through Unidecode. For other data, you can modify _characters.
See TRAINING_DATA.md for details.
'''
from .cmudict import valid_symbols


# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
_arpabet = ['@' + s for s in valid_symbols]


def get_symbols(symbol_set='english_basic'):
    if symbol_set == 'english_basic':
        _pad = '_'
        _punctuation = '!\'(),.:;? '
        _special = '-'
        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
    elif symbol_set == 'english_basic_lowercase':
        _pad = '_'
        _punctuation = '!\'"(),.:;? '
        _special = '-'
        _letters = 'abcdefghijklmnopqrstuvwxyz'
        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
    elif symbol_set == 'english_expanded':
        _punctuation = '!\'",.:;? '
        _math = '#%&*+-/[]()'
        _special = '_@©°½—₩€$'
        _accented = 'áçéêëñöøćž'
        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
        symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
    elif symbol_set == 'smj_expanded':
        _punctuation = '!\'",.:;?- '
        _math = '#%&*+-/[]()'
        _special = '_@©°½—₩€$'
        # _accented = 'áçéêëñöøćžđšŧ'  # also North Sámi letters
        _accented = 'áçéêëñöø'
        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTŦUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'  # note: uppercase Ŧ included
        # symbols = list(_punctuation + _math + _special + _accented + _letters)  # + _arpabet
        symbols = list(_punctuation + _letters) + _arpabet
    elif symbol_set == 'sme_expanded':
        _punctuation = '!\'",.:;?- '
        _math = '#%&*+-/[]()'
        _special = '_@©°½—₩€$'
        _accented = 'áçéêëńñöøćčžđšŧ'  # also North Sámi letters
        # _accented = 'áçéêëñöø'
        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
        _letters = 'AÁÆÅÄBCČDĐEFGHIJKLMNŊOØÖPQRSŠTŦUVWXYZŽaáæåäbcčdđefghijklmnŋoøöpqrsštŧuvwxyzž'
        # symbols = list(_punctuation + _math + _special + _accented + _letters)  # + _arpabet
        symbols = list(_punctuation + _letters) + _arpabet
    elif symbol_set == 'sma_expanded':
        _punctuation = '!\'",.:;?- '
        _math = '#%&*+-/[]()'
        _special = '_@©°½—₩€$'
        _accented = 'áäæçéêëïńñöøćčžđšŧ'  # also North Sámi letters
        # _accented = 'áçéêëñöø'
        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
        _letters = 'AÆÅBCDEFGHIÏJKLMNOØÖPQRSTUVWXYZaæåbcdefghiïjklmnoøöpqrstuvwxyz'
        # symbols = list(_punctuation + _math + _special + _accented + _letters)  # + _arpabet
        symbols = list(_punctuation + _letters) + _arpabet
    elif symbol_set == 'all_sami':
        _punctuation = '!\'",.:;?- '
        _math = '#%&*+-/[]()'
        _special = '_@©°½—₩€$'
        _accented = 'áäæçéêëïńñöøćčžđšŧ'
        _letters = 'AÁÆÅÄBCČDĐEFGHIÏJKLMNŊŃÑOØÖPQRSŠTŦUVWXYZŽaáæåäbcčdđefghiïjklmnŋńñoøöpqrsštŧuvwxyzž'
        symbols = list(_punctuation + _letters)  # + _arpabet
    else:
        raise Exception("{} symbol set does not exist".format(symbol_set))

    return symbols


def get_pad_idx(symbol_set='english_basic'):
    if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded',
                      'sme_expanded', 'sma_expanded', 'all_sami'}:
        return 0
    else:
        raise Exception("{} symbol set not used yet".format(symbol_set))
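A usage sketch tying the two helpers together into the character-to-id mapping a model front end would use; the symbol_to_id dict is illustrative, not part of this file:

from common.text.symbols import get_symbols, get_pad_idx

symbols = get_symbols('all_sami')
pad_idx = get_pad_idx('all_sami')                      # 0, i.e. the first symbol in the list
symbol_to_id = {s: i for i, s in enumerate(symbols)}   # maps each character to a model input id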