kathiasi committed on
Commit 16f0ad7 · verified · 1 Parent(s): 83e9b55

Upload 100 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. README.md +6 -5
  2. app.py +73 -0
  3. common/audio_processing.py +120 -0
  4. common/env.py +25 -0
  5. common/filter_warnings.py +33 -0
  6. common/gpu_affinity.py +156 -0
  7. common/layers.py +134 -0
  8. common/repeated_dataloader.py +59 -0
  9. common/stft.py +140 -0
  10. common/tb_dllogger.py +172 -0
  11. common/text/LICENSE +19 -0
  12. common/text/__init__.py +3 -0
  13. common/text/__pycache__/__init__.cpython-37.pyc +0 -0
  14. common/text/__pycache__/__init__.cpython-38.pyc +0 -0
  15. common/text/__pycache__/__init__.cpython-39.pyc +0 -0
  16. common/text/__pycache__/abbreviations.cpython-37.pyc +0 -0
  17. common/text/__pycache__/abbreviations.cpython-38.pyc +0 -0
  18. common/text/__pycache__/abbreviations.cpython-39.pyc +0 -0
  19. common/text/__pycache__/acronyms.cpython-37.pyc +0 -0
  20. common/text/__pycache__/acronyms.cpython-38.pyc +0 -0
  21. common/text/__pycache__/acronyms.cpython-39.pyc +0 -0
  22. common/text/__pycache__/cleaners.cpython-37.pyc +0 -0
  23. common/text/__pycache__/cleaners.cpython-38.pyc +0 -0
  24. common/text/__pycache__/cleaners.cpython-39.pyc +0 -0
  25. common/text/__pycache__/cmudict.cpython-37.pyc +0 -0
  26. common/text/__pycache__/cmudict.cpython-38.pyc +0 -0
  27. common/text/__pycache__/cmudict.cpython-39.pyc +0 -0
  28. common/text/__pycache__/datestime.cpython-37.pyc +0 -0
  29. common/text/__pycache__/datestime.cpython-38.pyc +0 -0
  30. common/text/__pycache__/datestime.cpython-39.pyc +0 -0
  31. common/text/__pycache__/letters_and_numbers.cpython-37.pyc +0 -0
  32. common/text/__pycache__/letters_and_numbers.cpython-38.pyc +0 -0
  33. common/text/__pycache__/letters_and_numbers.cpython-39.pyc +0 -0
  34. common/text/__pycache__/numerical.cpython-37.pyc +0 -0
  35. common/text/__pycache__/numerical.cpython-38.pyc +0 -0
  36. common/text/__pycache__/numerical.cpython-39.pyc +0 -0
  37. common/text/__pycache__/symbols.cpython-37.pyc +0 -0
  38. common/text/__pycache__/symbols.cpython-38.pyc +0 -0
  39. common/text/__pycache__/symbols.cpython-39.pyc +0 -0
  40. common/text/__pycache__/text_processing.cpython-37.pyc +0 -0
  41. common/text/__pycache__/text_processing.cpython-38.pyc +0 -0
  42. common/text/__pycache__/text_processing.cpython-39.pyc +0 -0
  43. common/text/abbreviations.py +67 -0
  44. common/text/acronyms.py +109 -0
  45. common/text/cleaners.py +102 -0
  46. common/text/cmudict.py +98 -0
  47. common/text/datestime.py +22 -0
  48. common/text/letters_and_numbers.py +90 -0
  49. common/text/numerical.py +153 -0
  50. common/text/symbols.py +81 -0
README.md CHANGED
@@ -1,14 +1,15 @@
 ---
-title: 6L TTS
+title: Multi Sami
-emoji: 🐨
+emoji: 🔥
 colorFrom: green
-colorTo: yellow
+colorTo: pink
 sdk: gradio
-sdk_version: 5.34.2
+sdk_version: 5.15.0
 app_file: app.py
 pinned: false
 license: cc-by-nc-nd-4.0
+#license: cc-by-4.0
-short_description: Multilingual TTS for Sámi languages
+short_description: Multilingual, multi-speaker Sámi TTS
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,73 @@
+ import gradio as gr
+ import syn_hifigan as syn
+ #import syn_k_univnet_multi as syn
+ import os, tempfile
+
+ languages = {"South Sámi": 0,
+              "North Sámi": 1,
+              "Lule Sámi": 2}
+
+ speakers = {#"aj0": 0,
+             "Aanna - sma": 1,
+             "Máhtte": 2,
+             "Siggá - smj": 3,
+             "Biret - sme": 5,
+             #"lo": 6,
+             "Sunná": 7,
+             "Abmut - smj": 8,
+             "Nihkol - smj": 9
+             }
+ public = True
+
+ tempdir = tempfile.gettempdir()
+
+ tts = syn.Synthesizer()
+
+
+ def speak(text, language, speaker, l_weight, s_weight, pace, postfilter):  # pitch_shift, pitch_std
+     # text frontend not implemented...
+     text = text.replace("...", "…")
+     print(speakers[speaker])
+     audio = tts.speak(text, output_file=f'{tempdir}/tmp', lang=languages[language],
+                       spkr=speakers[speaker], l_weight=l_weight, s_weight=s_weight,
+                       pace=pace, clarity=postfilter)
+
+     if not public:
+         try:
+             os.system("play " + tempdir + "/tmp.wav &")
+         except:
+             pass
+
+     return (22050, audio)
+
+
+ controls = []
+ controls.append(gr.Textbox(label="text", value="Suohtas duinna deaivvadit."))
+ controls.append(gr.Dropdown(list(languages.keys()), label="language", value="North Sámi"))
+ controls.append(gr.Dropdown(list(speakers.keys()), label="speaker", value="Sunná"))
+
+ controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1, label="language weight"))
+ controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1, label="speaker weight"))
+
+ controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1.0, label="speech rate"))
+ controls.append(gr.Slider(minimum=0., maximum=2, step=0.05, value=1.0, label="post-processing"))
+
+
+ tts_gui = gr.Interface(
+     fn=speak,
+     inputs=controls,
+     outputs=gr.Audio(label="output"),
+     live=False
+ )
+
+
+ if __name__ == "__main__":
+     tts_gui.launch(share=public)
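
The synthesis backend, syn_hifigan.Synthesizer, is outside this 50-file view, so the file above is only the UI layer. For reference, speak() can be exercised without Gradio; a minimal sketch, assuming Synthesizer() loads its default checkpoint on import and returns a NumPy waveform at the 22050 Hz rate hard-coded in the return tuple:

    # hypothetical smoke test for app.py; run from the repo root
    from app import speak

    sr, audio = speak("Suohtas duinna deaivvadit.", "North Sámi", "Sunná",
                      l_weight=1.0, s_weight=1.0, pace=1.0, postfilter=1.0)
    print(sr, getattr(audio, "shape", None))
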
common/audio_processing.py ADDED
@@ -0,0 +1,120 @@
+ # *****************************************************************************
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Redistribution and use in source and binary forms, with or without
+ # modification, are permitted provided that the following conditions are met:
+ #     * Redistributions of source code must retain the above copyright
+ #       notice, this list of conditions and the following disclaimer.
+ #     * Redistributions in binary form must reproduce the above copyright
+ #       notice, this list of conditions and the following disclaimer in the
+ #       documentation and/or other materials provided with the distribution.
+ #     * Neither the name of the NVIDIA CORPORATION nor the
+ #       names of its contributors may be used to endorse or promote products
+ #       derived from this software without specific prior written permission.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+ # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ #
+ # *****************************************************************************
+
+ import librosa.util as librosa_util
+ import numpy as np
+ import torch
+ from scipy.signal import get_window
+
+
+ def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
+                      n_fft=800, dtype=np.float32, norm=None):
+     """
+     # from librosa 0.6
+     Compute the sum-square envelope of a window function at a given hop length.
+
+     This is used to estimate modulation effects induced by windowing
+     observations in short-time fourier transforms.
+
+     Parameters
+     ----------
+     window : string, tuple, number, callable, or list-like
+         Window specification, as in `get_window`
+
+     n_frames : int > 0
+         The number of analysis frames
+
+     hop_length : int > 0
+         The number of samples to advance between frames
+
+     win_length : [optional]
+         The length of the window function. By default, this matches `n_fft`.
+
+     n_fft : int > 0
+         The length of each analysis frame.
+
+     dtype : np.dtype
+         The data type of the output
+
+     Returns
+     -------
+     wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
+         The sum-squared envelope of the window function
+     """
+     if win_length is None:
+         win_length = n_fft
+
+     n = n_fft + hop_length * (n_frames - 1)
+     x = np.zeros(n, dtype=dtype)
+
+     # Compute the squared window at the desired length
+     win_sq = get_window(window, win_length, fftbins=True)
+     win_sq = librosa_util.normalize(win_sq, norm=norm)**2
+     win_sq = librosa_util.pad_center(win_sq, size=n_fft)
+
+     # Fill the envelope
+     for i in range(n_frames):
+         sample = i * hop_length
+         x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
+     return x
+
+
+ def griffin_lim(magnitudes, stft_fn, n_iters=30):
+     """
+     PARAMS
+     ------
+     magnitudes: spectrogram magnitudes
+     stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
+     """
+
+     angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
+     angles = angles.astype(np.float32)
+     angles = torch.autograd.Variable(torch.from_numpy(angles))
+     signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+
+     for i in range(n_iters):
+         _, angles = stft_fn.transform(signal)
+         signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+     return signal
+
+
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
+     """
+     PARAMS
+     ------
+     C: compression factor
+     """
+     return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+ def dynamic_range_decompression(x, C=1):
+     """
+     PARAMS
+     ------
+     C: compression factor used to compress
+     """
+     return torch.exp(x) / C
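
griffin_lim pairs with the STFT module added later in this commit (common/stft.py) to resynthesize a waveform from magnitudes alone, re-estimating phase iteratively; a minimal sketch with random data standing in for speech:

    import torch
    from common.stft import STFT
    from common.audio_processing import griffin_lim

    stft = STFT(filter_length=1024, hop_length=256, win_length=1024)
    wave = torch.randn(1, 22050) * 0.1            # stand-in for 1 s of audio
    magnitudes, _ = stft.transform(wave)          # discard the true phase
    recon = griffin_lim(magnitudes, stft, n_iters=30)
    print(recon.shape)                            # (1, num_samples)
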
common/env.py ADDED
@@ -0,0 +1,25 @@
+ import os
+ import shutil
+ from collections import defaultdict
+
+
+ class AttrDict(dict):
+     def __init__(self, *args, **kwargs):
+         super(AttrDict, self).__init__(*args, **kwargs)
+         self.__dict__ = self
+
+
+ class DefaultAttrDict(defaultdict):
+     def __init__(self, *args, **kwargs):
+         super(DefaultAttrDict, self).__init__(*args, **kwargs)
+         self.__dict__ = self
+
+     def __getattr__(self, item):
+         return self[item]
+
+
+ def build_env(config, config_name, path):
+     t_path = os.path.join(path, config_name)
+     if config != t_path:
+         os.makedirs(path, exist_ok=True)
+         shutil.copyfile(config, os.path.join(path, config_name))
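
AttrDict simply aliases attribute access to dictionary keys, the usual trick for JSON-loaded HiFi-GAN-style configs; a minimal illustration:

    from common.env import AttrDict

    h = AttrDict({"sampling_rate": 22050, "hop_length": 256})
    assert h.sampling_rate == h["sampling_rate"] == 22050
    h.n_mel_channels = 80                 # attribute writes land in the dict
    assert h["n_mel_channels"] == 80
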
common/filter_warnings.py ADDED
@@ -0,0 +1,33 @@
+ # Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Mutes known and unrelated PyTorch warnings.
+
+ The warnings module keeps a list of filters. Importing it as late as possible
+ prevents its filters from being overridden.
+ """
+
+ import warnings
+
+
+ # NGC 22.04-py3 container (PyTorch 1.12.0a0+bd13bc6)
+ warnings.filterwarnings(
+     "ignore",
+     message='positional arguments and argument "destination" are deprecated.'
+             ' nn.Module.state_dict will not accept them in the future.')
+
+ # 22.08-py3 container
+ warnings.filterwarnings(
+     "ignore",
+     message="is_namedtuple is deprecated, please use the python checks")
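
The module works purely by side effect, so entry points just import it, after the heavyweight libraries, to keep its filters from being clobbered; a minimal sketch:

    import torch                    # heavy imports first
    import common.filter_warnings   # noqa: F401  (installs the filters)
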
common/gpu_affinity.py ADDED
@@ -0,0 +1,156 @@
+ # Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import collections
+ import math
+ import os
+ import pathlib
+ import re
+
+ import pynvml
+
+ pynvml.nvmlInit()
+
+
+ def systemGetDriverVersion():
+     return pynvml.nvmlSystemGetDriverVersion()
+
+
+ def deviceGetCount():
+     return pynvml.nvmlDeviceGetCount()
+
+
+ class device:
+     # assume nvml returns list of 64 bit ints
+     _nvml_affinity_elements = math.ceil(os.cpu_count() / 64)
+
+     def __init__(self, device_idx):
+         super().__init__()
+         self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
+
+     def getName(self):
+         return pynvml.nvmlDeviceGetName(self.handle)
+
+     def getCpuAffinity(self):
+         affinity_string = ''
+         for j in pynvml.nvmlDeviceGetCpuAffinity(
+                 self.handle, device._nvml_affinity_elements
+         ):
+             # assume nvml returns list of 64 bit ints
+             affinity_string = '{:064b}'.format(j) + affinity_string
+         affinity_list = [int(x) for x in affinity_string]
+         affinity_list.reverse()  # so core 0 is in 0th element of list
+
+         ret = [i for i, e in enumerate(affinity_list) if e != 0]
+         return ret
+
+
+ def set_socket_affinity(gpu_id):
+     dev = device(gpu_id)
+     affinity = dev.getCpuAffinity()
+     os.sched_setaffinity(0, affinity)
+
+
+ def set_single_affinity(gpu_id):
+     dev = device(gpu_id)
+     affinity = dev.getCpuAffinity()
+     os.sched_setaffinity(0, affinity[:1])
+
+
+ def set_single_unique_affinity(gpu_id, nproc_per_node):
+     devices = [device(i) for i in range(nproc_per_node)]
+     socket_affinities = [dev.getCpuAffinity() for dev in devices]
+
+     siblings_list = get_thread_siblings_list()
+     siblings_dict = dict(siblings_list)
+
+     # remove siblings
+     for idx, socket_affinity in enumerate(socket_affinities):
+         socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
+
+     affinities = []
+     assigned = []
+
+     for socket_affinity in socket_affinities:
+         for core in socket_affinity:
+             if core not in assigned:
+                 affinities.append([core])
+                 assigned.append(core)
+                 break
+     os.sched_setaffinity(0, affinities[gpu_id])
+
+
+ def set_socket_unique_affinity(gpu_id, nproc_per_node, mode):
+     device_ids = [device(i) for i in range(nproc_per_node)]
+     socket_affinities = [dev.getCpuAffinity() for dev in device_ids]
+
+     siblings_list = get_thread_siblings_list()
+     siblings_dict = dict(siblings_list)
+
+     # remove siblings
+     for idx, socket_affinity in enumerate(socket_affinities):
+         socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
+
+     socket_affinities_to_device_ids = collections.defaultdict(list)
+
+     for idx, socket_affinity in enumerate(socket_affinities):
+         socket_affinities_to_device_ids[tuple(socket_affinity)].append(idx)
+
+     for socket_affinity, device_ids in socket_affinities_to_device_ids.items():
+         devices_per_group = len(device_ids)
+         cores_per_device = len(socket_affinity) // devices_per_group
+         for group_id, device_id in enumerate(device_ids):
+             if device_id == gpu_id:
+                 if mode == 'interleaved':
+                     affinity = list(socket_affinity[group_id::devices_per_group])
+                 elif mode == 'continuous':
+                     affinity = list(socket_affinity[group_id*cores_per_device:(group_id+1)*cores_per_device])
+                 else:
+                     raise RuntimeError('Unknown set_socket_unique_affinity mode')
+
+                 # reintroduce siblings
+                 affinity += [siblings_dict[aff] for aff in affinity if aff in siblings_dict]
+                 os.sched_setaffinity(0, affinity)
+
+
+ def get_thread_siblings_list():
+     path = '/sys/devices/system/cpu/cpu*/topology/thread_siblings_list'
+     thread_siblings_list = []
+     pattern = re.compile(r'(\d+)\D(\d+)')
+     for fname in pathlib.Path(path[0]).glob(path[1:]):
+         with open(fname) as f:
+             content = f.read().strip()
+             res = pattern.findall(content)
+             if res:
+                 pair = tuple(map(int, res[0]))
+                 thread_siblings_list.append(pair)
+     return thread_siblings_list
+
+
+ def set_affinity(gpu_id, nproc_per_node, mode='socket'):
+     if mode == 'socket':
+         set_socket_affinity(gpu_id)
+     elif mode == 'single':
+         set_single_affinity(gpu_id)
+     elif mode == 'single_unique':
+         set_single_unique_affinity(gpu_id, nproc_per_node)
+     elif mode == 'socket_unique_interleaved':
+         set_socket_unique_affinity(gpu_id, nproc_per_node, 'interleaved')
+     elif mode == 'socket_unique_continuous':
+         set_socket_unique_affinity(gpu_id, nproc_per_node, 'continuous')
+     else:
+         raise RuntimeError('Unknown affinity mode')
+
+     affinity = os.sched_getaffinity(0)
+     return affinity
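
Each rank typically pins itself right after process launch; a minimal sketch, assuming a Linux host with NVIDIA drivers, pynvml installed, and the LOCAL_RANK/WORLD_SIZE variables that launchers such as torchrun export:

    import os
    from common.gpu_affinity import set_affinity

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    nproc_per_node = int(os.environ.get("WORLD_SIZE", 1))
    affinity = set_affinity(local_rank, nproc_per_node, mode='socket')
    print(f"rank {local_rank} pinned to CPUs {sorted(affinity)}")
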
common/layers.py ADDED
@@ -0,0 +1,134 @@
+ # *****************************************************************************
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Redistribution and use in source and binary forms, with or without
+ # modification, are permitted provided that the following conditions are met:
+ #     * Redistributions of source code must retain the above copyright
+ #       notice, this list of conditions and the following disclaimer.
+ #     * Redistributions in binary form must reproduce the above copyright
+ #       notice, this list of conditions and the following disclaimer in the
+ #       documentation and/or other materials provided with the distribution.
+ #     * Neither the name of the NVIDIA CORPORATION nor the
+ #       names of its contributors may be used to endorse or promote products
+ #       derived from this software without specific prior written permission.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+ # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ #
+ # *****************************************************************************
+
+ import torch
+ import torch.nn.functional as F
+ from librosa.filters import mel as librosa_mel_fn
+
+ from common.audio_processing import (dynamic_range_compression,
+                                      dynamic_range_decompression)
+ from common.stft import STFT
+
+
+ class LinearNorm(torch.nn.Module):
+     def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
+         super(LinearNorm, self).__init__()
+         self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+
+         torch.nn.init.xavier_uniform_(
+             self.linear_layer.weight,
+             gain=torch.nn.init.calculate_gain(w_init_gain))
+
+     def forward(self, x):
+         return self.linear_layer(x)
+
+
+ class ConvNorm(torch.nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
+                  padding=None, dilation=1, bias=True, w_init_gain='linear',
+                  batch_norm=False):
+         super(ConvNorm, self).__init__()
+         if padding is None:
+             assert(kernel_size % 2 == 1)
+             padding = int(dilation * (kernel_size - 1) / 2)
+
+         self.conv = torch.nn.Conv1d(in_channels, out_channels,
+                                     kernel_size=kernel_size, stride=stride,
+                                     padding=padding, dilation=dilation,
+                                     bias=bias)
+         self.norm = torch.nn.BatchNorm1d(out_channels) if batch_norm else None
+
+         torch.nn.init.xavier_uniform_(
+             self.conv.weight,
+             gain=torch.nn.init.calculate_gain(w_init_gain))
+
+     def forward(self, signal):
+         if self.norm is None:
+             return self.conv(signal)
+         else:
+             return self.norm(self.conv(signal))
+
+
+ class ConvReLUNorm(torch.nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size=1, dropout=0.0):
+         super(ConvReLUNorm, self).__init__()
+         self.conv = torch.nn.Conv1d(in_channels, out_channels,
+                                     kernel_size=kernel_size,
+                                     padding=(kernel_size // 2))
+         self.norm = torch.nn.LayerNorm(out_channels)
+         self.dropout = torch.nn.Dropout(dropout)
+
+     def forward(self, signal):
+         out = F.relu(self.conv(signal))
+         out = self.norm(out.transpose(1, 2)).transpose(1, 2).to(signal.dtype)
+         return self.dropout(out)
+
+
+ class TacotronSTFT(torch.nn.Module):
+     def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
+                  n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
+                  mel_fmax=8000.0):
+         super(TacotronSTFT, self).__init__()
+         self.n_mel_channels = n_mel_channels
+         self.sampling_rate = sampling_rate
+         self.stft_fn = STFT(filter_length, hop_length, win_length)
+         mel_basis = librosa_mel_fn(
+             sr=sampling_rate,
+             n_fft=filter_length,
+             n_mels=n_mel_channels,
+             fmin=mel_fmin,
+             fmax=mel_fmax
+         )
+         mel_basis = torch.from_numpy(mel_basis).float()
+         self.register_buffer('mel_basis', mel_basis)
+
+     def spectral_normalize(self, magnitudes):
+         output = dynamic_range_compression(magnitudes)
+         return output
+
+     def spectral_de_normalize(self, magnitudes):
+         output = dynamic_range_decompression(magnitudes)
+         return output
+
+     def mel_spectrogram(self, y):
+         """Computes mel-spectrograms from a batch of waves
+         PARAMS
+         ------
+         y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
+
+         RETURNS
+         -------
+         mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
+         """
+         assert(torch.min(y.data) >= -1)
+         assert(torch.max(y.data) <= 1)
+
+         magnitudes, phases = self.stft_fn.transform(y)
+         magnitudes = magnitudes.data
+         mel_output = torch.matmul(self.mel_basis, magnitudes)
+         mel_output = self.spectral_normalize(mel_output)
+         return mel_output
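
TacotronSTFT is the training-time feature extractor: waveforms scaled to [-1, 1] go in, log-compressed mel spectrograms come out; a minimal sketch (librosa must be installed for the mel filterbank):

    import torch
    from common.layers import TacotronSTFT

    taco_stft = TacotronSTFT(filter_length=1024, hop_length=256,
                             win_length=1024, n_mel_channels=80,
                             sampling_rate=22050)
    wave = torch.rand(2, 22050) * 1.9 - 0.95   # two fake 1-second clips
    mel = taco_stft.mel_spectrogram(wave)
    print(mel.shape)                           # (2, 80, n_frames)
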
common/repeated_dataloader.py ADDED
@@ -0,0 +1,59 @@
+ # Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Data pipeline elements which wrap the data N times
+
+ A RepeatedDataLoader resets its iterator less frequently. This saves time
+ on multi-GPU platforms and is invisible to the training loop.
+
+ NOTE: Repeating puts a block of (len(dataset) * repeats) int64s into RAM.
+ Do not use more repeats than necessary (e.g., 10**6 to simulate infinity).
+ """
+
+ import itertools
+
+ from torch.utils.data import DataLoader
+ from torch.utils.data.distributed import DistributedSampler
+
+
+ class RepeatedDataLoader(DataLoader):
+     def __init__(self, repeats, *args, **kwargs):
+         self.repeats = repeats
+         super().__init__(*args, **kwargs)
+
+     def __iter__(self):
+         if self._iterator is None or self.repeats_done >= self.repeats:
+             self.repeats_done = 1
+             return super().__iter__()
+         else:
+             self.repeats_done += 1
+             return self._iterator
+
+
+ class RepeatedDistributedSampler(DistributedSampler):
+     def __init__(self, repeats, *args, **kwargs):
+         self.repeats = repeats
+         assert self.repeats <= 10000, "Too many repeats overload RAM."
+         super().__init__(*args, **kwargs)
+
+     def __iter__(self):
+         # Draw indices for `self.repeats` epochs forward
+         start_epoch = self.epoch
+         iters = []
+         for r in range(self.repeats):
+             self.set_epoch(start_epoch + r)
+             iters.append(super().__iter__())
+         self.set_epoch(start_epoch)
+
+         return itertools.chain.from_iterable(iters)
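
The two classes are meant to be used together: the sampler emits `repeats` epochs of indices in one pass, and the loader keeps handing back the same iterator until that many epochs have been consumed, so a training loop should take a fixed number of batches per epoch rather than drain the iterator. A minimal single-process sketch (num_replicas and rank pinned so no process group is needed):

    import itertools
    import torch
    from torch.utils.data import TensorDataset
    from common.repeated_dataloader import (RepeatedDataLoader,
                                            RepeatedDistributedSampler)

    if __name__ == "__main__":
        repeats, bs = 4, 8
        ds = TensorDataset(torch.arange(64))
        sampler = RepeatedDistributedSampler(repeats, ds, num_replicas=1, rank=0)
        dl = RepeatedDataLoader(repeats, ds, batch_size=bs, sampler=sampler,
                                num_workers=2, persistent_workers=True)
        for epoch in range(2 * repeats):   # workers reset twice, not 8 times
            for (batch,) in itertools.islice(iter(dl), len(ds) // bs):
                pass
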
common/stft.py ADDED
@@ -0,0 +1,140 @@
+ """
+ BSD 3-Clause License
+
+ Copyright (c) 2017, Prem Seetharaman
+ All rights reserved.
+
+ * Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice,
+   this list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice, this
+   list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived from this
+   software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ """
+
+ import torch
+ import numpy as np
+ import torch.nn.functional as F
+ from torch.autograd import Variable
+ from scipy.signal import get_window
+ from librosa.util import pad_center, tiny
+ from common.audio_processing import window_sumsquare
+
+
+ class STFT(torch.nn.Module):
+     """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
+     def __init__(self, filter_length=800, hop_length=200, win_length=800,
+                  window='hann', device="cpu"):
+         super(STFT, self).__init__()
+         self.filter_length = filter_length
+         self.hop_length = hop_length
+         self.win_length = win_length
+         self.window = window
+         self.forward_transform = None
+         scale = self.filter_length / self.hop_length
+         fourier_basis = np.fft.fft(np.eye(self.filter_length))
+
+         cutoff = int((self.filter_length / 2 + 1))
+         fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
+                                    np.imag(fourier_basis[:cutoff, :])])
+
+         forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
+         inverse_basis = torch.FloatTensor(
+             np.linalg.pinv(scale * fourier_basis).T[:, None, :].copy())
+
+         if window is not None:
+             assert(filter_length >= win_length)
+             # get window and zero center pad it to filter_length
+             fft_window = get_window(window, win_length, fftbins=True)
+             fft_window = pad_center(fft_window, size=filter_length)
+             fft_window = torch.from_numpy(fft_window).float()
+
+             # window the bases
+             forward_basis *= fft_window
+             inverse_basis *= fft_window
+
+         self.register_buffer('forward_basis', forward_basis.float().to(device))
+         self.register_buffer('inverse_basis', inverse_basis.float().to(device))
+
+     def transform(self, input_data):
+         num_batches = input_data.size(0)
+         num_samples = input_data.size(1)
+
+         self.num_samples = num_samples
+
+         # similar to librosa, reflect-pad the input
+         input_data = input_data.view(num_batches, 1, num_samples)
+         input_data = F.pad(
+             input_data.unsqueeze(1),
+             (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
+             mode='reflect')
+         input_data = input_data.squeeze(1)
+         # print(self.forward_basis.device)
+         forward_transform = F.conv1d(
+             input_data,
+             Variable(self.forward_basis, requires_grad=False),
+             stride=self.hop_length,
+             padding=0)
+
+         cutoff = int((self.filter_length / 2) + 1)
+         real_part = forward_transform[:, :cutoff, :]
+         imag_part = forward_transform[:, cutoff:, :]
+
+         magnitude = torch.sqrt(real_part**2 + imag_part**2)
+         phase = torch.autograd.Variable(
+             torch.atan2(imag_part.data, real_part.data))
+
+         return magnitude, phase
+
+     def inverse(self, magnitude, phase):
+         recombine_magnitude_phase = torch.cat(
+             [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
+
+         with torch.no_grad():
+             inverse_transform = F.conv_transpose1d(
+                 recombine_magnitude_phase, self.inverse_basis,
+                 stride=self.hop_length, padding=0)
+
+         if self.window is not None:
+             window_sum = window_sumsquare(
+                 self.window, magnitude.size(-1), hop_length=self.hop_length,
+                 win_length=self.win_length, n_fft=self.filter_length,
+                 dtype=np.float32)
+             # remove modulation effects
+             approx_nonzero_indices = torch.from_numpy(
+                 np.where(window_sum > tiny(window_sum))[0])
+             window_sum = torch.autograd.Variable(
+                 torch.from_numpy(window_sum), requires_grad=False)
+             window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
+             inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
+
+             # scale by hop ratio
+             inverse_transform *= float(self.filter_length) / self.hop_length
+
+         inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
+         inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2)]
+
+         return inverse_transform
+
+     def forward(self, input_data):
+         self.magnitude, self.phase = self.transform(input_data)
+         reconstruction = self.inverse(self.magnitude, self.phase)
+         return reconstruction
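
Because forward() chains transform() and inverse(), a round trip reproduces the waveform up to windowing edge effects; a minimal sketch:

    import torch
    from common.stft import STFT

    stft = STFT(filter_length=800, hop_length=200, win_length=800)
    wave = torch.randn(1, 16000) * 0.1
    recon = stft(wave)                 # transform + inverse
    print(wave.shape, recon.shape)     # (1, 16000) vs (1, 1, 16000)
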
common/tb_dllogger.py ADDED
@@ -0,0 +1,172 @@
+ import atexit
+ import glob
+ import re
+ from itertools import product
+ from pathlib import Path
+
+ import dllogger
+ import torch
+ import numpy as np
+ from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
+ from torch.utils.tensorboard import SummaryWriter
+
+
+ tb_loggers = {}
+
+
+ class TBLogger:
+     """
+     xyz_dummies: stretch the screen with empty plots so the legend would
+                  always fit for other plots
+     """
+     def __init__(self, enabled, log_dir, name, interval=1, dummies=True):
+         self.enabled = enabled
+         self.interval = interval
+         self.cache = {}
+         if self.enabled:
+             self.summary_writer = SummaryWriter(
+                 log_dir=Path(log_dir, name), flush_secs=120, max_queue=200)
+             atexit.register(self.summary_writer.close)
+             if dummies:
+                 for key in ('_', '✕'):
+                     self.summary_writer.add_scalar(key, 0.0, 1)
+
+     def log(self, step, data):
+         for k, v in data.items():
+             self.log_value(step, k, v.item() if type(v) is torch.Tensor else v)
+
+     def log_value(self, step, key, val, stat='mean'):
+         if self.enabled:
+             if key not in self.cache:
+                 self.cache[key] = []
+             self.cache[key].append(val)
+             if len(self.cache[key]) == self.interval:
+                 agg_val = getattr(np, stat)(self.cache[key])
+                 self.summary_writer.add_scalar(key, agg_val, step)
+                 del self.cache[key]
+
+     def log_grads(self, step, model):
+         if self.enabled:
+             norms = [p.grad.norm().item() for p in model.parameters()
+                      if p.grad is not None]
+             for stat in ('max', 'min', 'mean'):
+                 self.log_value(step, f'grad_{stat}', getattr(np, stat)(norms),
+                                stat=stat)
+
+
+ def unique_log_fpath(fpath):
+
+     if not Path(fpath).is_file():
+         return fpath
+
+     # Avoid overwriting old logs
+     saved = [re.search(r'\.(\d+)$', f) for f in glob.glob(f'{fpath}.*')]
+     saved = [0] + [int(m.group(1)) for m in saved if m is not None]
+     return f'{fpath}.{max(saved) + 1}'
+
+
+ def stdout_step_format(step):
+     if isinstance(step, str):
+         return step
+     fields = []
+     if len(step) > 0:
+         fields.append("epoch {:>4}".format(step[0]))
+     if len(step) > 1:
+         fields.append("iter {:>3}".format(step[1]))
+     if len(step) > 2:
+         fields[-1] += "/{}".format(step[2])
+     return " | ".join(fields)
+
+
+ def stdout_metric_format(metric, metadata, value):
+     name = metadata.get("name", metric + " : ")
+     unit = metadata.get("unit", None)
+     format = f'{{{metadata.get("format", "")}}}'
+     fields = [name, format.format(value) if value is not None else value, unit]
+     fields = [f for f in fields if f is not None]
+     return "| " + " ".join(fields)
+
+
+ def init(log_fpath, log_dir, enabled=True, tb_subsets=[], **tb_kw):
+
+     if enabled:
+         backends = [JSONStreamBackend(Verbosity.DEFAULT,
+                                       unique_log_fpath(log_fpath)),
+                     StdOutBackend(Verbosity.VERBOSE,
+                                   step_format=stdout_step_format,
+                                   metric_format=stdout_metric_format)]
+     else:
+         backends = []
+
+     dllogger.init(backends=backends)
+     dllogger.metadata("train_lrate", {"name": "lrate", "unit": None, "format": ":>3.2e"})
+
+     for id_, pref in [('train', ''), ('train_avg', 'avg train '),
+                       ('val', '  avg val '), ('val_ema', '  EMA val ')]:
+
+         dllogger.metadata(f"{id_}_loss",
+                           {"name": f"{pref}loss", "unit": None, "format": ":>5.2f"})
+         dllogger.metadata(f"{id_}_mel_loss",
+                           {"name": f"{pref}mel loss", "unit": None, "format": ":>5.2f"})
+
+         dllogger.metadata(f"{id_}_kl_loss",
+                           {"name": f"{pref}kl loss", "unit": None, "format": ":>5.5f"})
+         dllogger.metadata(f"{id_}_kl_weight",
+                           {"name": f"{pref}kl weight", "unit": None, "format": ":>5.5f"})
+
+         dllogger.metadata(f"{id_}_frames/s",
+                           {"name": None, "unit": "frames/s", "format": ":>10.2f"})
+         dllogger.metadata(f"{id_}_took",
+                           {"name": "took", "unit": "s", "format": ":>3.2f"})
+
+     global tb_loggers
+     tb_loggers = {s: TBLogger(enabled, log_dir, name=s, **tb_kw)
+                   for s in tb_subsets}
+
+
+ def init_inference_metadata(batch_size=None):
+
+     modalities = [('latency', 's', ':>10.5f'), ('RTF', 'x', ':>10.2f'),
+                   ('frames/s', 'frames/s', ':>10.2f'), ('samples/s', 'samples/s', ':>10.2f'),
+                   ('letters/s', 'letters/s', ':>10.2f'), ('tokens/s', 'tokens/s', ':>10.2f')]
+
+     if batch_size is not None:
+         modalities.append((f'RTF@{batch_size}', 'x', ':>10.2f'))
+
+     percs = ['', 'avg', '90%', '95%', '99%']
+     models = ['', 'fastpitch', 'waveglow', 'hifigan']
+
+     for perc, model, (mod, unit, fmt) in product(percs, models, modalities):
+         name = f'{perc} {model} {mod}'.strip().replace('  ', ' ')
+         dllogger.metadata(name.replace(' ', '_'),
+                           {'name': f'{name: <26}', 'unit': unit, 'format': fmt})
+
+
+ def log(step, tb_total_steps=None, data={}, subset='train'):
+     if tb_total_steps is not None:
+         tb_loggers[subset].log(tb_total_steps, data)
+
+     if subset != '':
+         data = {f'{subset}_{key}': v for key, v in data.items()}
+     dllogger.log(step, data=data)
+
+
+ def log_grads_tb(tb_total_steps, grads, tb_subset='train'):
+     tb_loggers[tb_subset].log_grads(tb_total_steps, grads)
+
+
+ def parameters(data, verbosity=0, tb_subset=None):
+     for k, v in data.items():
+         dllogger.log(step="PARAMETER", data={k: v}, verbosity=verbosity)
+
+     if tb_subset is not None and tb_loggers[tb_subset].enabled:
+         tb_data = {k: v for k, v in data.items()
+                    if type(v) in (str, bool, int, float)}
+         tb_loggers[tb_subset].summary_writer.add_hparams(tb_data, {})
+
+
+ def flush():
+     dllogger.flush()
+     for tbl in tb_loggers.values():
+         if tbl.enabled:
+             tbl.summary_writer.flush()
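
A training script initializes the logger once, then logs per-subset metrics keyed to the metadata registered in init(); a minimal sketch (dllogger and tensorboard must be installed; paths are illustrative):

    import os
    import common.tb_dllogger as logger

    os.makedirs('logs', exist_ok=True)
    logger.init(log_fpath='logs/nvlog.json', log_dir='logs',
                tb_subsets=['train', 'val'])
    logger.parameters({'batch_size': 16, 'learning_rate': 0.1},
                      tb_subset='train')
    logger.log((1, 10, 100),          # rendered as "epoch 1 | iter 10/100"
               tb_total_steps=10,
               data={'loss': 5.23, 'mel_loss': 4.87},
               subset='train')
    logger.flush()
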
common/text/LICENSE ADDED
@@ -0,0 +1,19 @@
+ Copyright (c) 2017 Keith Ito
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in
+ all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ THE SOFTWARE.
common/text/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .cmudict import CMUDict
+
+ cmudict = CMUDict()
common/text/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (203 Bytes)
common/text/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (230 Bytes)
common/text/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (230 Bytes)
common/text/__pycache__/abbreviations.cpython-37.pyc ADDED
Binary file (1.84 kB)
common/text/__pycache__/abbreviations.cpython-38.pyc ADDED
Binary file (1.87 kB)
common/text/__pycache__/abbreviations.cpython-39.pyc ADDED
Binary file (1.87 kB)
common/text/__pycache__/acronyms.cpython-37.pyc ADDED
Binary file (2.55 kB)
common/text/__pycache__/acronyms.cpython-38.pyc ADDED
Binary file (2.68 kB)
common/text/__pycache__/acronyms.cpython-39.pyc ADDED
Binary file (2.54 kB)
common/text/__pycache__/cleaners.cpython-37.pyc ADDED
Binary file (2.7 kB)
common/text/__pycache__/cleaners.cpython-38.pyc ADDED
Binary file (2.77 kB)
common/text/__pycache__/cleaners.cpython-39.pyc ADDED
Binary file (2.75 kB)
common/text/__pycache__/cmudict.cpython-37.pyc ADDED
Binary file (3.79 kB)
common/text/__pycache__/cmudict.cpython-38.pyc ADDED
Binary file (3.99 kB)
common/text/__pycache__/cmudict.cpython-39.pyc ADDED
Binary file (3.69 kB)
common/text/__pycache__/datestime.cpython-37.pyc ADDED
Binary file (720 Bytes)
common/text/__pycache__/datestime.cpython-38.pyc ADDED
Binary file (757 Bytes)
common/text/__pycache__/datestime.cpython-39.pyc ADDED
Binary file (757 Bytes)
common/text/__pycache__/letters_and_numbers.cpython-37.pyc ADDED
Binary file (2.88 kB)
common/text/__pycache__/letters_and_numbers.cpython-38.pyc ADDED
Binary file (2.94 kB)
common/text/__pycache__/letters_and_numbers.cpython-39.pyc ADDED
Binary file (2.92 kB)
common/text/__pycache__/numerical.cpython-37.pyc ADDED
Binary file (4.65 kB)
common/text/__pycache__/numerical.cpython-38.pyc ADDED
Binary file (4.7 kB)
common/text/__pycache__/numerical.cpython-39.pyc ADDED
Binary file (4.71 kB)
common/text/__pycache__/symbols.cpython-37.pyc ADDED
Binary file (1.5 kB)
common/text/__pycache__/symbols.cpython-38.pyc ADDED
Binary file (1.96 kB)
common/text/__pycache__/symbols.cpython-39.pyc ADDED
Binary file (2.16 kB)
common/text/__pycache__/text_processing.cpython-37.pyc ADDED
Binary file (4.96 kB)
common/text/__pycache__/text_processing.cpython-38.pyc ADDED
Binary file (5.04 kB)
common/text/__pycache__/text_processing.cpython-39.pyc ADDED
Binary file (4.36 kB)
common/text/abbreviations.py ADDED
@@ -0,0 +1,67 @@
+ import re
+
+ _no_period_re = re.compile(r'(No[.])(?=[ ]?[0-9])')
+ _percent_re = re.compile(r'([ ]?[%])')
+ _half_re = re.compile('([0-9]½)|(½)')
+ _url_re = re.compile(r'([a-zA-Z])\.(com|gov|org)')
+
+
+ # List of (regular expression, replacement) pairs for abbreviations:
+ _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
+     ('mrs', 'misess'),
+     ('ms', 'miss'),
+     ('mr', 'mister'),
+     ('dr', 'doctor'),
+     ('st', 'saint'),
+     ('co', 'company'),
+     ('jr', 'junior'),
+     ('maj', 'major'),
+     ('gen', 'general'),
+     ('drs', 'doctors'),
+     ('rev', 'reverend'),
+     ('lt', 'lieutenant'),
+     ('hon', 'honorable'),
+     ('sgt', 'sergeant'),
+     ('capt', 'captain'),
+     ('esq', 'esquire'),
+     ('ltd', 'limited'),
+     ('col', 'colonel'),
+     ('ft', 'fort'),
+     ('sen', 'senator'),
+     ('etc', 'et cetera'),
+ ]]
+
+
+ def _expand_no_period(m):
+     word = m.group(0)
+     if word[0] == 'N':
+         return 'Number'
+     return 'number'
+
+
+ def _expand_percent(m):
+     return ' percent'
+
+
+ def _expand_half(m):
+     word = m.group(1)
+     if word is None:
+         return 'half'
+     return word[0] + ' and a half'
+
+
+ def _expand_urls(m):
+     return f'{m.group(1)} dot {m.group(2)}'
+
+
+ def normalize_abbreviations(text):
+     text = re.sub(_no_period_re, _expand_no_period, text)
+     text = re.sub(_percent_re, _expand_percent, text)
+     text = re.sub(_half_re, _expand_half, text)
+     text = re.sub('&', ' and ', text)
+     text = re.sub('@', ' at ', text)
+     text = re.sub(_url_re, _expand_urls, text)
+
+     for regex, replacement in _abbreviations:
+         text = re.sub(regex, replacement, text)
+     return text
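
A minimal check of the expansion rules:

    from common.text.abbreviations import normalize_abbreviations

    print(normalize_abbreviations("Mr. Smith earned 100% at example.com"))
    # -> mister Smith earned 100 percent at example dot com
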
common/text/acronyms.py ADDED
@@ -0,0 +1,109 @@
+ import re
+ from . import cmudict
+
+ _letter_to_arpabet = {
+     'A': 'EY1',
+     'B': 'B IY1',
+     'C': 'S IY1',
+     'D': 'D IY1',
+     'E': 'IY1',
+     'F': 'EH1 F',
+     'G': 'JH IY1',
+     'H': 'EY1 CH',
+     'I': 'AY1',
+     'J': 'JH EY1',
+     'K': 'K EY1',
+     'L': 'EH1 L',
+     'M': 'EH1 M',
+     'N': 'EH1 N',
+     'O': 'OW1',
+     'P': 'P IY1',
+     'Q': 'K Y UW1',
+     'R': 'AA1 R',
+     'S': 'EH1 S',
+     'T': 'T IY1',
+     'U': 'Y UW1',
+     'V': 'V IY1',
+     'X': 'EH1 K S',
+     'Y': 'W AY1',
+     'W': 'D AH1 B AH0 L Y UW0',
+     'Z': 'Z IY1',
+     's': 'Z'
+ }
+
+ # Acronyms that should not be expanded
+ hardcoded_acronyms = [
+     'BMW', 'MVD', 'WDSU', 'GOP', 'UK', 'AI', 'GPS', 'BP', 'FBI', 'HD',
+     'CES', 'LRA', 'PC', 'NBA', 'BBL', 'OS', 'IRS', 'SAC', 'UV', 'CEO', 'TV',
+     'CNN', 'MSS', 'GSA', 'USSR', 'DNA', 'PRS', 'TSA', 'US', 'GPU', 'USA',
+     'FPCC', 'CIA']
+
+ # Words and acronyms that should be read as regular words, e.g., NATO, HAPPY, etc.
+ uppercase_whitelist = []
+
+ acronyms_exceptions = {
+     'NVIDIA': 'N.VIDIA',
+ }
+
+ non_uppercase_exceptions = {
+     'email': 'e-mail',
+ }
+
+ # must ignore roman numerals
+ _acronym_re = re.compile(r'([a-z]*[A-Z][A-Z]+)s?\.?')
+ _non_uppercase_re = re.compile(r'\b({})\b'.format('|'.join(non_uppercase_exceptions.keys())), re.IGNORECASE)
+
+
+ def _expand_acronyms_to_arpa(m, add_spaces=True):
+     acronym = m.group(0)
+
+     # remove dots if they exist
+     acronym = re.sub(r'\.', '', acronym)
+
+     acronym = "".join(acronym.split())
+     arpabet = cmudict.lookup(acronym)
+
+     if arpabet is None:
+         acronym = list(acronym)
+         arpabet = ["{" + _letter_to_arpabet[letter] + "}" for letter in acronym]
+         # temporary fix
+         if arpabet[-1] == '{Z}' and len(arpabet) > 1:
+             arpabet[-2] = arpabet[-2][:-1] + ' ' + arpabet[-1][1:]
+             del arpabet[-1]
+
+         arpabet = ' '.join(arpabet)
+     elif len(arpabet) == 1:
+         arpabet = "{" + arpabet[0] + "}"
+     else:
+         arpabet = acronym
+
+     return arpabet
+
+
+ def normalize_acronyms(text):
+     text = re.sub(_acronym_re, _expand_acronyms_to_arpa, text)
+     return text
+
+
+ def expand_acronyms(m):
+     text = m.group(1)
+     if text in acronyms_exceptions:
+         text = acronyms_exceptions[text]
+     elif text in uppercase_whitelist:
+         text = text
+     else:
+         text = '.'.join(text) + '.'
+
+     if 's' in m.group(0):
+         text = text + '\'s'
+
+     if text[-1] != '.' and m.group(0)[-1] == '.':
+         return text + '.'
+     else:
+         return text
+
+
+ def spell_acronyms(text):
+     text = re.sub(_non_uppercase_re, lambda m: non_uppercase_exceptions[m.group(0).lower()], text)
+     text = re.sub(_acronym_re, expand_acronyms, text)
+     return text
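
spell_acronyms rewrites unrecognized uppercase runs letter by letter, applying the exception table first; a minimal check:

    from common.text.acronyms import spell_acronyms

    print(spell_acronyms("NVIDIA makes a GPU for NASA"))
    # -> N.VIDIA makes a G.P.U. for N.A.S.A.
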
common/text/cleaners.py ADDED
@@ -0,0 +1,102 @@
+ """ adapted from https://github.com/keithito/tacotron """
+
+ '''
+ Cleaners are transformations that run over the input text at both training and eval time.
+
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
+ hyperparameter. Some cleaners are English-specific. You'll typically want to use:
+   1. "english_cleaners" for English text
+   2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
+      the Unidecode library (https://pypi.python.org/pypi/Unidecode)
+   3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
+      the symbols in symbols.py to match your data).
+ '''
+
+ import re
+ from .abbreviations import normalize_abbreviations
+ from .acronyms import normalize_acronyms, spell_acronyms
+ from .datestime import normalize_datestime
+ from .letters_and_numbers import normalize_letters_and_numbers
+ from .numerical import normalize_numbers
+ from .unidecoder import unidecoder
+
+
+ # Regular expression matching whitespace:
+ _whitespace_re = re.compile(r'\s+')
+
+
+ def expand_abbreviations(text):
+     return normalize_abbreviations(text)
+
+
+ def expand_numbers(text):
+     return normalize_numbers(text)
+
+
+ def expand_acronyms(text):
+     return normalize_acronyms(text)
+
+
+ def expand_datestime(text):
+     return normalize_datestime(text)
+
+
+ def expand_letters_and_numbers(text):
+     return normalize_letters_and_numbers(text)
+
+
+ def lowercase(text):
+     return text.lower()
+
+
+ def collapse_whitespace(text):
+     return re.sub(_whitespace_re, ' ', text)
+
+
+ def separate_acronyms(text):
+     text = re.sub(r"([0-9]+)([a-zA-Z]+)", r"\1 \2", text)
+     text = re.sub(r"([a-zA-Z]+)([0-9]+)", r"\1 \2", text)
+     return text
+
+
+ def convert_to_ascii(text):
+     return unidecoder(text)
+
+
+ def basic_cleaners(text):
+     '''Basic pipeline that collapses whitespace without transliteration.'''
+     # text = lowercase(text)
+     text = collapse_whitespace(text)
+     return text
+
+
+ def transliteration_cleaners(text):
+     '''Pipeline for non-English text that transliterates to ASCII.'''
+     text = convert_to_ascii(text)
+     text = lowercase(text)
+     text = collapse_whitespace(text)
+     return text
+
+
+ def english_cleaners(text):
+     '''Pipeline for English text, with number and abbreviation expansion.'''
+     text = convert_to_ascii(text)
+     text = lowercase(text)
+     text = expand_numbers(text)
+     text = expand_abbreviations(text)
+     text = collapse_whitespace(text)
+     return text
+
+
+ def english_cleaners_v2(text):
+     text = convert_to_ascii(text)
+     text = expand_datestime(text)
+     text = expand_letters_and_numbers(text)
+     text = expand_numbers(text)
+     text = expand_abbreviations(text)
+     text = spell_acronyms(text)
+     text = lowercase(text)
+     text = collapse_whitespace(text)
+     # compatibility with basic_english symbol set
+     text = re.sub(r'/+', ' ', text)
+     return text
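
For Sámi input the relevant pipeline is basic_cleaners, which only collapses whitespace and leaves non-ASCII letters intact (english_cleaners_v2 additionally needs the unidecoder module, which falls outside this 50-file view); a minimal check:

    from common.text.cleaners import basic_cleaners

    print(basic_cleaners("Suohtas   duinna\tdeaivvadit."))
    # -> Suohtas duinna deaivvadit.
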
common/text/cmudict.py ADDED
@@ -0,0 +1,98 @@
+ """ from https://github.com/keithito/tacotron """
+
+ import re
+ import sys
+ import urllib.request
+ from pathlib import Path
+
+
+ valid_symbols = [
+     'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
+     'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
+     'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
+     'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
+     'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
+     'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
+     'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
+ ]
+
+ _valid_symbol_set = set(valid_symbols)
+
+
+ class CMUDict:
+     '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
+     def __init__(self, file_or_path=None, heteronyms_path=None, keep_ambiguous=True):
+         self._entries = {}
+         self.heteronyms = []
+         if file_or_path is not None:
+             self.initialize(file_or_path, heteronyms_path, keep_ambiguous)
+
+     def initialize(self, file_or_path, heteronyms_path, keep_ambiguous=True):
+         if isinstance(file_or_path, str):
+             if not Path(file_or_path).exists():
+                 print("CMUdict missing. Downloading to data/cmudict/.")
+                 self.download()
+
+             with open(file_or_path, encoding='latin-1') as f:
+                 entries = _parse_cmudict(f)
+         else:
+             entries = _parse_cmudict(file_or_path)
+
+         if not keep_ambiguous:
+             entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
+         self._entries = entries
+
+         if heteronyms_path is not None:
+             with open(heteronyms_path, encoding='utf-8') as f:
+                 self.heteronyms = [l.rstrip() for l in f]
+
+     def __len__(self):
+         if len(self._entries) == 0:
+             raise ValueError("CMUDict not initialized")
+         return len(self._entries)
+
+     def lookup(self, word):
+         '''Returns list of ARPAbet pronunciations of the given word.'''
+         if len(self._entries) == 0:
+             raise ValueError("CMUDict not initialized")
+         return self._entries.get(word.upper())
+
+     def download(self):
+         url = 'https://github.com/Alexir/CMUdict/raw/master/cmudict-0.7b'
+         try:
+             Path('cmudict').mkdir(parents=False, exist_ok=True)
+             urllib.request.urlretrieve(url, filename='cmudict/cmudict-0.7b')
+         except:
+             print("Automatic download of CMUdict failed. Try manually with:")
+             print()
+             print("    bash scripts/download_cmudict.sh")
+             print()
+             print("and re-run the script.")
+             sys.exit(0)
+
+
+ _alt_re = re.compile(r'\([0-9]+\)')
+
+
+ def _parse_cmudict(file):
+     cmudict = {}
+     for line in file:
+         if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
+             parts = line.split('  ')
+             word = re.sub(_alt_re, '', parts[0])
+             pronunciation = _get_pronunciation(parts[1])
+             if pronunciation:
+                 if word in cmudict:
+                     cmudict[word].append(pronunciation)
+                 else:
+                     cmudict[word] = [pronunciation]
+     return cmudict
+
+
+ def _get_pronunciation(s):
+     parts = s.strip().split(' ')
+     for part in parts:
+         if part not in _valid_symbol_set:
+             return None
+     return ' '.join(parts)
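
lookup() returns every ARPAbet variant of a word, which is how heteronyms surface; a minimal sketch, assuming cmudict-0.7b is present at (or downloadable to) the path below:

    from common.text.cmudict import CMUDict

    d = CMUDict('cmudict/cmudict-0.7b')
    print(d.lookup('read'))   # heteronym: ['R EH1 D', 'R IY1 D']
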
common/text/datestime.py ADDED
@@ -0,0 +1,22 @@
+ import re
+
+ _ampm_re = re.compile(
+     r'([0-9]|0[0-9]|1[0-9]|2[0-3]):?([0-5][0-9])?\s*([AaPp][Mm]\b)')
+
+
+ def _expand_ampm(m):
+     matches = list(m.groups(0))
+     txt = matches[0]
+     txt = txt if int(matches[1]) == 0 else txt + ' ' + matches[1]
+
+     if matches[2][0].lower() == 'a':
+         txt += ' a.m.'
+     elif matches[2][0].lower() == 'p':
+         txt += ' p.m.'
+
+     return txt
+
+
+ def normalize_datestime(text):
+     text = re.sub(_ampm_re, _expand_ampm, text)
+     #text = re.sub(r"([0-9]|0[0-9]|1[0-9]|2[0-3]):([0-5][0-9])?", r"\1 \2", text)
+     return text
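
A minimal check of the am/pm expansion:

    from common.text.datestime import normalize_datestime

    print(normalize_datestime("Lunch at 12:30pm, coffee at 9am."))
    # -> Lunch at 12 30 p.m., coffee at 9 a.m.
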
common/text/letters_and_numbers.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+import re
+_letters_and_numbers_re = re.compile(
+    r"((?:[a-zA-Z]+[0-9]|[0-9]+[a-zA-Z])[a-zA-Z0-9']*)", re.IGNORECASE)
+
+_hardware_re = re.compile(
+    r'([0-9]+(?:[.,][0-9]+)?)(?:\s?)(tb|gb|mb|kb|ghz|mhz|khz|hz|mm)', re.IGNORECASE)
+_hardware_key = {'tb': 'terabyte',
+                 'gb': 'gigabyte',
+                 'mb': 'megabyte',
+                 'kb': 'kilobyte',
+                 'ghz': 'gigahertz',
+                 'mhz': 'megahertz',
+                 'khz': 'kilohertz',
+                 'hz': 'hertz',
+                 'mm': 'millimeter',
+                 'cm': 'centimeter',
+                 'km': 'kilometer'}
+
+_dimension_re = re.compile(
+    r'\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b|\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b')
+_dimension_key = {'m': 'meter',
+                  'in': 'inch',
+                  'inch': 'inch'}
+
+
+
+
+def _expand_letters_and_numbers(m):
+    text = re.split(r'(\d+)', m.group(0))
+
+    # remove empty strings left at the edges by the split
+    if text[-1] == '':
+        text = text[:-1]
+    elif text[0] == '':
+        text = text[1:]
+
+    # reattach suffixes like 1920s, AK47's, 20th, 1st, 2nd, 3rd to the digits
+    if text[-1] in ("'s", "s", "th", "nd", "st", "rd") and text[-2].isdigit():
+        text[-2] = text[-2] + text[-1]
+        text = text[:-1]
+
+    # for combining digits 2 by 2
+    new_text = []
+    for i in range(len(text)):
+        string = text[i]
+        if string.isdigit() and len(string) < 5:
+            # heuristics
+            if len(string) > 2 and string[-2] == '0':
+                if string[-1] == '0':
+                    string = [string]
+                else:
+                    string = [string[:-2], string[-2], string[-1]]
+            elif len(string) % 2 == 0:
+                string = [string[i:i+2] for i in range(0, len(string), 2)]
+            elif len(string) > 2:
+                string = [string[0]] + [string[i:i+2] for i in range(1, len(string), 2)]
+            new_text.extend(string)
+        else:
+            new_text.append(string)
+
+    text = new_text
+    text = " ".join(text)
+    return text
+
+
+def _expand_hardware(m):
+    quantity, measure = m.groups(0)
+    measure = _hardware_key[measure.lower()]
+    if measure[-1] != 'z' and float(quantity.replace(',', '')) > 1:
+        return "{} {}s".format(quantity, measure)
+    return "{} {}".format(quantity, measure)
+
+
+def _expand_dimension(m):
+    text = "".join([x for x in m.groups(0) if x != 0])
+    text = text.replace(' x ', ' by ')
+    text = text.replace('x', ' by ')
+    if text.endswith(tuple(_dimension_key.keys())):
+        if text[-2].isdigit():
+            text = "{} {}".format(text[:-1], _dimension_key[text[-1:]])
+        elif text[-3].isdigit():
+            text = "{} {}".format(text[:-2], _dimension_key[text[-2:]])
+    return text
+
+
+def normalize_letters_and_numbers(text):
+    text = re.sub(_hardware_re, _expand_hardware, text)
+    text = re.sub(_dimension_re, _expand_dimension, text)
+    text = re.sub(_letters_and_numbers_re, _expand_letters_and_numbers, text)
+    return text
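
A minimal sketch of the three rules composed by normalize_letters_and_numbers (inputs are illustrative):

    from common.text.letters_and_numbers import normalize_letters_and_numbers

    print(normalize_letters_and_numbers('a 2TB disk'))  # 'a 2 terabytes disk'
    print(normalize_letters_and_numbers('3x4 tiles'))   # '3 by 4 tiles'
    print(normalize_letters_and_numbers('RTX3090'))     # 'RTX 30 90' (digits grouped 2 by 2)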
common/text/numerical.py ADDED
@@ -0,0 +1,153 @@
+""" adapted from https://github.com/keithito/tacotron """
+
+import inflect
+import re
+_magnitudes = ['trillion', 'billion', 'million', 'thousand', 'hundred', 'm', 'b', 't']
+_magnitudes_key = {'m': 'million', 'b': 'billion', 't': 'trillion'}
+_measurements = '(f|c|k|d|m)'
+_measurements_key = {'f': 'fahrenheit',
+                     'c': 'celsius',
+                     'k': 'thousand',
+                     'm': 'meters'}
+_currency_key = {'$': 'dollar', '£': 'pound', '€': 'euro', '₩': 'won'}
+_inflect = inflect.engine()
+_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
+_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
+_currency_re = re.compile(r'([\$€£₩])([0-9\.\,]*[0-9]+)(?:[ ]?({})(?=[^a-zA-Z]|$))?'.format("|".join(_magnitudes)), re.IGNORECASE)
+_measurement_re = re.compile(r'([0-9\.\,]*[0-9]+(\s)?{}\b)'.format(_measurements), re.IGNORECASE)
+_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
+# _range_re = re.compile(r'(?<=[0-9])+(-)(?=[0-9])+.*?')
+_roman_re = re.compile(r'\b(?=[MDCLXVI]+\b)M{0,4}(CM|CD|D?C{0,3})(XC|XL|L?X{0,3})(IX|IV|V?I{2,3})\b')  # avoid I
+_multiply_re = re.compile(r'(\b[0-9]+)(x)([0-9]+)')
+_number_re = re.compile(r"[0-9]+'s|[0-9]+s|[0-9]+")
+
+def _remove_commas(m):
+    return m.group(1).replace(',', '')
+
+
+def _expand_decimal_point(m):
+    return m.group(1).replace('.', ' point ')
+
+
+def _expand_currency(m):
+    currency = _currency_key[m.group(1)]
+    quantity = m.group(2)
+    magnitude = m.group(3)
+
+    # remove commas from quantity to be able to convert to numerical
+    quantity = quantity.replace(',', '')
+
+    # check for million, billion, etc...
+    if magnitude is not None and magnitude.lower() in _magnitudes:
+        if len(magnitude) == 1:
+            magnitude = _magnitudes_key[magnitude.lower()]
+        return "{} {} {}".format(_expand_hundreds(quantity), magnitude, currency+'s')
+
+    parts = quantity.split('.')
+    if len(parts) > 2:
+        return quantity + " " + currency + "s"  # Unexpected format
+
+    dollars = int(parts[0]) if parts[0] else 0
+
+    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
+    if dollars and cents:
+        dollar_unit = currency if dollars == 1 else currency+'s'
+        cent_unit = 'cent' if cents == 1 else 'cents'
+        return "{} {}, {} {}".format(
+            _expand_hundreds(dollars), dollar_unit,
+            _inflect.number_to_words(cents), cent_unit)
+    elif dollars:
+        dollar_unit = currency if dollars == 1 else currency+'s'
+        return "{} {}".format(_expand_hundreds(dollars), dollar_unit)
+    elif cents:
+        cent_unit = 'cent' if cents == 1 else 'cents'
+        return "{} {}".format(_inflect.number_to_words(cents), cent_unit)
+    else:
+        return 'zero' + ' ' + currency + 's'
+
+
+def _expand_hundreds(text):
+    number = float(text)
+    if 1000 < number < 10000 and (number % 100 == 0) and (number % 1000 != 0):
+        return _inflect.number_to_words(int(number / 100)) + " hundred"
+    else:
+        return _inflect.number_to_words(text)
+
+
+def _expand_ordinal(m):
+    return _inflect.number_to_words(m.group(0))
+
+
+def _expand_measurement(m):
+    _, number, measurement = re.split(r'(\d+(?:\.\d+)?)', m.group(0))
+    number = _inflect.number_to_words(number)
+    measurement = "".join(measurement.split())
+    measurement = _measurements_key[measurement.lower()]
+    return "{} {}".format(number, measurement)
+
+
+def _expand_range(m):
+    return ' to '
+
+
+def _expand_multiply(m):
+    left = m.group(1)
+    right = m.group(3)
+    return "{} by {}".format(left, right)
+
+
+def _expand_roman(m):
+    # from https://stackoverflow.com/questions/19308177/converting-roman-numerals-to-integers-in-python
+    roman_numerals = {'I': 1, 'V': 5, 'X': 10, 'L': 50, 'C': 100, 'D': 500, 'M': 1000}
+    result = 0
+    num = m.group(0)
+    for i, c in enumerate(num):
+        if (i+1) == len(num) or roman_numerals[c] >= roman_numerals[num[i+1]]:
+            result += roman_numerals[c]
+        else:
+            result -= roman_numerals[c]
+    return str(result)
+
+
+def _expand_number(m):
+    _, number, suffix = re.split(r"(\d+(?:'?\d+)?)", m.group(0))
+    number = int(number)
+    if 1000 < number < 10000 and (number % 100 == 0) and (number % 1000 != 0):
+        text = _inflect.number_to_words(number // 100) + " hundred"
+    elif number > 1000 and number < 3000:
+        if number == 2000:
+            text = 'two thousand'
+        elif number > 2000 and number < 2010:
+            text = 'two thousand ' + _inflect.number_to_words(number % 100)
+        elif number % 100 == 0:
+            text = _inflect.number_to_words(number // 100) + ' hundred'
+        else:
+            number = _inflect.number_to_words(number, andword='', zero='oh', group=2).replace(', ', ' ')
+            number = re.sub(r'-', ' ', number)
+            text = number
+    else:
+        number = _inflect.number_to_words(number, andword='and')
+        number = re.sub(r'-', ' ', number)
+        number = re.sub(r',', '', number)
+        text = number
+
+    if suffix in ("'s", "s"):
+        if text[-1] == 'y':
+            text = text[:-1] + 'ies'
+        else:
+            text = text + suffix
+
+    return text
+
+
+def normalize_numbers(text):
+    text = re.sub(_comma_number_re, _remove_commas, text)
+    text = re.sub(_currency_re, _expand_currency, text)
+    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
+    text = re.sub(_ordinal_re, _expand_ordinal, text)
+    # text = re.sub(_range_re, _expand_range, text)
+    # text = re.sub(_measurement_re, _expand_measurement, text)
+    text = re.sub(_roman_re, _expand_roman, text)
+    text = re.sub(_multiply_re, _expand_multiply, text)
+    text = re.sub(_number_re, _expand_number, text)
+    return text
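
A minimal sketch of normalize_numbers on representative inputs (illustrative; the year-style grouping comes from _expand_number above):

    from common.text.numerical import normalize_numbers

    print(normalize_numbers('$3.50'))        # 'three dollars, fifty cents'
    print(normalize_numbers('in 1984'))      # 'in nineteen eighty four'
    print(normalize_numbers('Chapter XIV'))  # 'Chapter fourteen'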
common/text/symbols.py ADDED
@@ -0,0 +1,81 @@
+""" from https://github.com/keithito/tacotron """
+
+'''
+Defines the set of symbols used in text input to the model.
+
+The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
+from .cmudict import valid_symbols
+
+
+# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
+_arpabet = ['@' + s for s in valid_symbols]
+
+
+def get_symbols(symbol_set='english_basic'):
+    if symbol_set == 'english_basic':
+        _pad = '_'
+        _punctuation = '!\'(),.:;? '
+        _special = '-'
+        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
+        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
+    elif symbol_set == 'english_basic_lowercase':
+        _pad = '_'
+        _punctuation = '!\'"(),.:;? '
+        _special = '-'
+        _letters = 'abcdefghijklmnopqrstuvwxyz'
+        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
+    elif symbol_set == 'english_expanded':
+        _punctuation = '!\'",.:;? '
+        _math = '#%&*+-/[]()'
+        _special = '_@©°½—₩€$'
+        _accented = 'áçéêëñöøćž'
+        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
+        symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
+    elif symbol_set == 'smj_expanded':
+        _punctuation = '!\'",.:;?- '
+        _math = '#%&*+-/[]()'
+        _special = '_@©°½—₩€$'
+        # _accented = 'áçéêëñöøćžđšŧ'  # also North Sámi letters
+        _accented = 'áçéêëñöø'  # also North Sámi letters
+        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
+        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTŦUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'  # NB: uppercase Ŧ included
+        # symbols = list(_punctuation + _math + _special + _accented + _letters)  # + _arpabet
+        symbols = list(_punctuation + _letters) + _arpabet
+    elif symbol_set == 'sme_expanded':
+        _punctuation = '!\'",.:;?- '
+        _math = '#%&*+-/[]()'
+        _special = '_@©°½—₩€$'
+        _accented = 'áçéêëńñöøćčžđšŧ'  # also North Sámi letters
+        # _accented = 'áçéêëñöø'
+        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
+        _letters = 'AÁÆÅÄBCČDĐEFGHIJKLMNŊOØÖPQRSŠTŦUVWXYZŽaáæåäbcčdđefghijklmnŋoøöpqrsštŧuvwxyzž'
+        # symbols = list(_punctuation + _math + _special + _accented + _letters)  # + _arpabet
+        symbols = list(_punctuation + _letters) + _arpabet
+    elif symbol_set == 'sma_expanded':
+        _punctuation = '!\'",.:;?- '
+        _math = '#%&*+-/[]()'
+        _special = '_@©°½—₩€$'
+        _accented = 'áäæçéêëïńñöøćčžđšŧ'  # also North Sámi letters
+        # _accented = 'áçéêëñöø'
+        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
+        _letters = 'AÆÅBCDEFGHIÏJKLMNOØÖPQRSTUVWXYZaæåbcdefghiïjklmnoøöpqrstuvwxyz'
+        # symbols = list(_punctuation + _math + _special + _accented + _letters)  # + _arpabet
+        symbols = list(_punctuation + _letters) + _arpabet
+    elif symbol_set == 'all_sami':
+        _punctuation = '!\'",.:;?- '
+        _math = '#%&*+-/[]()'
+        _special = '_@©°½—₩€$'
+        _accented = 'áäæçéêëïńñöøćčžđšŧ'
+        _letters = 'AÁÆÅÄBCČDĐEFGHIÏJKLMNŊŃÑOØÖPQRSŠTŦUVWXYZŽaáæåäbcčdđefghiïjklmnŋńñoøöpqrsštŧuvwxyzž'
+        symbols = list(_punctuation + _letters)  # + _arpabet
+    else:
+        raise Exception("{} symbol set does not exist".format(symbol_set))
+
+    return symbols
+
+
+def get_pad_idx(symbol_set='english_basic'):
+    if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded', 'sme_expanded', 'sma_expanded', 'all_sami'}:
+        return 0
+    else:
+        raise Exception("no pad index defined for the {} symbol set".format(symbol_set))