katrihiovain committed · commit 95b5cf1 · parent 2b63853
removed unnecessary files and updated app_py
Files changed:
- README1.md +0 -15
- app.py +1 -1
- common/text/symbols_2.py +0 -64
- common/text/symbols_BACKUP_MARCH_2024.py +0 -64
- common/text/symbols_ORIGINAL.py +0 -65
- common/text/symbols_backup.py +0 -54
- common/text/symbols_sme.py +0 -64
- common/text/symbols_sme_1.py +0 -64
- common/text/symbols_smj.py +0 -48
- common/utils_hfg.py +0 -14
- common/utils_ok.py +0 -291
- fastpitch/data_function (copy).py.txt +0 -425
- fastpitch/data_function_model_py.zip +0 -3
- fastpitch/utils_trainplot_transformers.zip +0 -3
- fastpitch/utils_trainplot_transformers/train_1_with_plot.py +0 -591
- fastpitch/utils_trainplot_transformers/transformer.py +0 -213
- fastpitch/utils_trainplot_transformers/transformer_jit.py +0 -255
- fastpitch/utils_trainplot_transformers/utils.py +0 -291
- gradio_gui.py +0 -74
- gradio_gui_katri.py +0 -73
- prepare_dataset.py +0 -180
- run_training_cluster_s.sh +0 -33
- scripts/docker/build.sh +0 -3
- scripts/docker/interactive.sh +0 -5
- scripts/download_cmudict.sh +0 -10
- scripts/download_dataset.sh +0 -17
- scripts/download_models.sh +0 -63
- scripts/inference_benchmark.sh +0 -16
- scripts/inference_example.sh +0 -78
- scripts/prepare_dataset.sh +0 -19
- scripts/train.sh +0 -100
- scripts/train_multilang.sh +0 -110
- train_1_with_plot_multilang.py +0 -593
README1.md
DELETED
@@ -1,15 +0,0 @@
-# FastPitchMulti
-Experimental multi-lingual FastPitch
-
-What's done:
-- [x] Conditioning on language and speaker labels
-- [x] Dataset and preprocessing of Sámi data
-- [x] Combined character set for the Sámi languages
-- [x] Train a model on Sámi languages
-- [x] Selecting Estonian data
-- [x] Processing Estonian data
-- [ ] Train a model on Sámi x 3, Finnish, Estonian
-
-Ideas:
-- Move the language embedding to the very beginning of the encoder
-
app.py
CHANGED
@@ -11,7 +11,7 @@ speakers={"aj0": 0,
     "aj1": 1,
     "am": 2,
     "bi": 3,
-    "kd": 4,
+    #"kd": 4,
     "ln": 5,
     "lo": 6,
     "ms": 7,

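For context, the speakers mapping edited here is what translates a speaker label picked in the UI into the integer ID the model's speaker embedding expects. Commenting out "kd" removes that voice without renumbering the rest, so the remaining IDs stay aligned with the trained embedding table. A minimal sketch of the lookup (only the dict contents come from the diff; the helper function is hypothetical):

    # Speaker table as it stands after this commit; ID 4 ("kd") is disabled
    # but deliberately not renumbered, keeping IDs aligned with the model.
    speakers = {"aj0": 0, "aj1": 1, "am": 2, "bi": 3,
                # "kd": 4,
                "ln": 5, "lo": 6, "ms": 7}

    def speaker_id(label: str) -> int:
        """Hypothetical helper: resolve a UI label to a model speaker ID."""
        if label not in speakers:
            raise ValueError(f"Unknown or disabled speaker: {label!r}")
        return speakers[label]
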
common/text/symbols_2.py
DELETED
@@ -1,64 +0,0 @@
-""" from https://github.com/keithito/tacotron """
-
-'''
-Defines the set of symbols used in text input to the model.
-
-The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
-from .cmudict import valid_symbols
-
-
-# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-_arpabet = ['@' + s for s in valid_symbols]
-
-
-def get_symbols(symbol_set='english_basic'):
-    if symbol_set == 'english_basic':
-        _pad = '_'
-        _punctuation = '!\'(),.:;? '
-        _special = '-'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_basic_lowercase':
-        _pad = '_'
-        _punctuation = '!\'"(),.:;? '
-        _special = '-'
-        _letters = 'abcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_expanded':
-        _punctuation = '!\'",.:;? '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëñöøćž'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
-    elif symbol_set == 'smj_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        # _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
-        _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    elif symbol_set == 'sme_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëńñöøćčžđšŧ' #also north sámi letters...
-        # _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋoøöpqrstuvwxyz'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    else:
-        raise Exception("{} symbol set does not exist".format(symbol_set))
-
-    return symbols
-
-
-def get_pad_idx(symbol_set='english_basic'):
-    if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded', 'sme_expanded'}:
-        return 0
-    else:
-        raise Exception("{} symbol set not used yet".format(symbol_set))

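This file and the remaining symbols_*.py variants deleted below all expose the same two functions, so a single hedged usage sketch covers them (assuming the surviving common/text/symbols.py keeps this interface):

    # Assumption: the canonical symbols module kept by this commit has the
    # same get_symbols/get_pad_idx API as the deleted variants.
    from common.text.symbols import get_symbols, get_pad_idx

    symbols = get_symbols('sme_expanded')   # punctuation + Sámi letters + ARPAbet codes
    pad_idx = get_pad_idx('sme_expanded')   # returns 0 for every supported set

    # Text is encoded into symbol indices before it reaches the model:
    sym_to_id = {s: i for i, s in enumerate(symbols)}
    ids = [sym_to_id[ch] for ch in 'buorre beaivi' if ch in sym_to_id]

Note that the expanded sets build symbols from _punctuation + _letters without an explicit '_' pad character, so the index 0 that get_pad_idx returns actually points at '!'.
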
common/text/symbols_BACKUP_MARCH_2024.py
DELETED
@@ -1,64 +0,0 @@
-""" from https://github.com/keithito/tacotron """
-
-'''
-Defines the set of symbols used in text input to the model.
-
-The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
-from .cmudict import valid_symbols
-
-
-# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-_arpabet = ['@' + s for s in valid_symbols]
-
-
-def get_symbols(symbol_set='english_basic'):
-    if symbol_set == 'english_basic':
-        _pad = '_'
-        _punctuation = '!\'(),.:;? '
-        _special = '-'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_basic_lowercase':
-        _pad = '_'
-        _punctuation = '!\'"(),.:;? '
-        _special = '-'
-        _letters = 'abcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_expanded':
-        _punctuation = '!\'",.:;? '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëñöøćž'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
-    elif symbol_set == 'smj_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        # _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
-        _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTŦUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz' ########################## Ŧ ########################
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    elif symbol_set == 'sme_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëńñöøćčžđšŧ' #also north sámi letters...
-        # _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCČDĐEFGHIJKLMNŊOØÖPQRSŠTŦUVWXYZŽaáæåäbcčdđefghijklmnŋoøöpqrsštŧuvwxyzž'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    else:
-        raise Exception("{} symbol set does not exist".format(symbol_set))
-
-    return symbols
-
-
-def get_pad_idx(symbol_set='english_basic'):
-    if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded', 'sme_expanded'}:
-        return 0
-    else:
-        raise Exception("{} symbol set not used yet".format(symbol_set))

common/text/symbols_ORIGINAL.py
DELETED
@@ -1,65 +0,0 @@
-""" from https://github.com/keithito/tacotron """
-
-'''
-Defines the set of symbols used in text input to the model.
-
-The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
-from .cmudict import valid_symbols
-
-
-# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-_arpabet = ['@' + s for s in valid_symbols]
-
-
-def get_symbols(symbol_set='english_basic'):
-    if symbol_set == 'english_basic':
-        _pad = '_'
-        _punctuation = '!\'(),.:;? '
-        _special = '-'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_basic_lowercase':
-        _pad = '_'
-        _punctuation = '!\'"(),.:;? '
-        _special = '-'
-        _letters = 'abcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_expanded':
-        _punctuation = '!\'",.:;? '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëñöøćž'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
-    elif symbol_set == 'smj_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        # _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
-        _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTŦUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    elif symbol_set == 'sme_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëńñöøćčžđšŧ' #also north sámi letters...
-        # _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCČDĐEFGHIJKLMNŊOØÖPQRSŠTŦUVWXYZŽaáæåäbcčdđefghijklmnŋoøöpqrsštŧuvwxyzž'
-        # _letters = 'AÁÆÅÄBCDĐEFGHIJKLMNŊOØÖPQRSŠTŦUVWXYZŽaáæåäbcčdđefghijklmnŋoøöpqrsštŧuvwxyzž'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    else:
-        raise Exception("{} symbol set does not exist".format(symbol_set))
-
-    return symbols
-
-
-def get_pad_idx(symbol_set='english_basic'):
-    if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded', 'sme_expanded'}:
-        return 0
-    else:
-        raise Exception("{} symbol set not used yet".format(symbol_set))

common/text/symbols_backup.py
DELETED
@@ -1,54 +0,0 @@
-""" from https://github.com/keithito/tacotron """
-
-'''
-Defines the set of symbols used in text input to the model.
-
-The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
-from .cmudict import valid_symbols
-
-
-# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-_arpabet = ['@' + s for s in valid_symbols]
-
-
-def get_symbols(symbol_set='english_basic'):
-    if symbol_set == 'english_basic':
-        _pad = '_'
-        _punctuation = '!\'(),.:;? '
-        _special = '-'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_basic_lowercase':
-        _pad = '_'
-        _punctuation = '!\'"(),.:;? '
-        _special = '-'
-        _letters = 'abcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_expanded':
-        _punctuation = '!\'",.:;? '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëñöøćž'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
-    elif symbol_set == 'smj_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        # _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
-        _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    else:
-        raise Exception("{} symbol set does not exist".format(symbol_set))
-
-    return symbols
-
-
-def get_pad_idx(symbol_set='english_basic'):
-    if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded'}:
-        return 0
-    else:
-        raise Exception("{} symbol set not used yet".format(symbol_set))

common/text/symbols_sme.py
DELETED
@@ -1,64 +0,0 @@
-""" from https://github.com/keithito/tacotron """
-
-'''
-Defines the set of symbols used in text input to the model.
-
-The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
-from .cmudict import valid_symbols
-
-
-# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-_arpabet = ['@' + s for s in valid_symbols]
-
-
-def get_symbols(symbol_set='english_basic'):
-    if symbol_set == 'english_basic':
-        _pad = '_'
-        _punctuation = '!\'(),.:;? '
-        _special = '-'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_basic_lowercase':
-        _pad = '_'
-        _punctuation = '!\'"(),.:;? '
-        _special = '-'
-        _letters = 'abcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_expanded':
-        _punctuation = '!\'",.:;? '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëñöøćž'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
-    elif symbol_set == 'smj_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        # _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
-        _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    elif symbol_set == 'sme_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëńñöøćčžđšŧ' #also north sámi letters...
-        # _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋoøöpqrstuvwxyz'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    else:
-        raise Exception("{} symbol set does not exist".format(symbol_set))
-
-    return symbols
-
-
-def get_pad_idx(symbol_set='english_basic'):
-    if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded', 'sme_expanded'}:
-        return 0
-    else:
-        raise Exception("{} symbol set not used yet".format(symbol_set))

common/text/symbols_sme_1.py
DELETED
@@ -1,64 +0,0 @@
-""" from https://github.com/keithito/tacotron """
-
-'''
-Defines the set of symbols used in text input to the model.
-
-The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
-from .cmudict import valid_symbols
-
-
-# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-_arpabet = ['@' + s for s in valid_symbols]
-
-
-def get_symbols(symbol_set='english_basic'):
-    if symbol_set == 'english_basic':
-        _pad = '_'
-        _punctuation = '!\'(),.:;? '
-        _special = '-'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_basic_lowercase':
-        _pad = '_'
-        _punctuation = '!\'"(),.:;? '
-        _special = '-'
-        _letters = 'abcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_expanded':
-        _punctuation = '!\'",.:;? '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëñöøćž'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
-    elif symbol_set == 'smj_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        # _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
-        _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    elif symbol_set == 'sme_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëńñöøćčžđšŧ' #also north sámi letters...
-        # _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDĐEFGHIJKLMNŊOØÖPQRSŠTUVWXYZŽaáæåäbcdđefghijklmnŋoøöpqrsštŧuvwxyzž'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    else:
-        raise Exception("{} symbol set does not exist".format(symbol_set))
-
-    return symbols
-
-
-def get_pad_idx(symbol_set='english_basic'):
-    if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded', 'sme_expanded'}:
-        return 0
-    else:
-        raise Exception("{} symbol set not used yet".format(symbol_set))

common/text/symbols_smj.py
DELETED
@@ -1,48 +0,0 @@
-""" from https://github.com/keithito/tacotron """
-
-'''
-Defines the set of symbols used in text input to the model.
-
-The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
-from .cmudict import valid_symbols
-
-
-# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-_arpabet = ['@' + s for s in valid_symbols]
-
-
-def get_symbols(symbol_set='smj_basic'):
-    if symbol_set == 'smj_basic':
-        _pad = '_'
-        _punctuation = '!\'(),.:;? '
-        _special = '-'
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
-        # LETTERS = 'AaÁáBbCcDdEeFfGgHhIiJjKkLlMmNnŊŋOoPpRrSsTtUuVvZzŃńÑñÆæØøÅåÄäÖö'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'smj_basic_lowercase':
-        _pad = '_'
-        _punctuation = '!\'"(),.:;? '
-        _special = '-'
-        # _letters = 'abcdefghijklmnopqrstuvwxyz'
-        _letters = 'aáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'smj_expanded':
-        _punctuation = '!\'",.:;? '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
-        symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
-    else:
-        raise Exception("{} symbol set does not exist".format(symbol_set))
-
-    return symbols
-
-
-def get_pad_idx(symbol_set='smj_basic'):
-    if symbol_set in {'smj_basic', 'smj_basic_lowercase'}:
-        return 0
-    else:
-        raise Exception("{} symbol set not used yet".format(symbol_set))

common/utils_hfg.py
DELETED
@@ -1,14 +0,0 @@
-
-##############################################################################
-# Foreing utils.py from HiFi-GAN
-##############################################################################
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size*dilation - dilation)/2)

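The get_padding helper above computes the padding that keeps a stride-1 dilated 1-D convolution length-preserving: dilation * (kernel_size - 1) / 2. A quick worked check, using a standalone copy of the deleted helper (example values chosen for illustration):

    def get_padding(kernel_size, dilation=1):
        # int((k*d - d)/2) simplifies to dilation*(kernel_size - 1)//2
        return int((kernel_size * dilation - dilation) / 2)

    assert get_padding(3) == 1              # plain 3-tap conv: pad 1 per side
    assert get_padding(3, dilation=2) == 2  # dilated receptive field of 5
    assert get_padding(7, dilation=4) == 12
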
common/utils_ok.py
DELETED
@@ -1,291 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# MIT License
-#
-# Copyright (c) 2020 Jungil Kong
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-# The following functions/classes were based on code from https://github.com/jik876/hifi-gan:
-# init_weights, get_padding, AttrDict
-
-import ctypes
-import glob
-import os
-import re
-import shutil
-import warnings
-from collections import defaultdict, OrderedDict
-from pathlib import Path
-from typing import Optional
-
-import librosa
-import numpy as np
-
-import torch
-import torch.distributed as dist
-from scipy.io.wavfile import read
-
-
-def mask_from_lens(lens, max_len: Optional[int] = None):
-    if max_len is None:
-        max_len = lens.max()
-    ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
-    mask = torch.lt(ids, lens.unsqueeze(1))
-    return mask
-
-
-def load_wav(full_path, torch_tensor=False):
-    import soundfile  # flac
-    data, sampling_rate = soundfile.read(full_path, dtype='int16')
-    if torch_tensor:
-        return torch.FloatTensor(data.astype(np.float32)), sampling_rate
-    else:
-        return data, sampling_rate
-
-
-def load_wav_to_torch(full_path, force_sampling_rate=None):
-    if force_sampling_rate is not None:
-        data, sampling_rate = librosa.load(full_path, sr=force_sampling_rate)
-    else:
-        sampling_rate, data = read(full_path)
-
-    return torch.FloatTensor(data.astype(np.float32)), sampling_rate
-
-
-def load_filepaths_and_text(dataset_path, fnames, has_speakers=False, split="|"):
-    def split_line(root, line):
-        parts = line.strip().split(split)
-        if has_speakers:
-            paths, non_paths = parts[:-2], parts[-2:]
-        else:
-            paths, non_paths = parts[:-1], parts[-1:]
-        return tuple(str(Path(root, p)) for p in paths) + tuple(non_paths)
-
-    fpaths_and_text = []
-    for fname in fnames:
-        with open(fname, encoding='utf-8') as f:
-            fpaths_and_text += [split_line(dataset_path, line) for line in f]
-    return fpaths_and_text
-
-
-def to_gpu(x):
-    x = x.contiguous()
-    return x.cuda(non_blocking=True) if torch.cuda.is_available() else x
-
-
-def l2_promote():
-    _libcudart = ctypes.CDLL('libcudart.so')
-    # Set device limit on the current device
-    # cudaLimitMaxL2FetchGranularity = 0x05
-    pValue = ctypes.cast((ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int))
-    _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
-    _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
-    assert pValue.contents.value == 128
-
-
-def prepare_tmp(path):
-    if path is None:
-        return
-    p = Path(path)
-    if p.is_dir():
-        warnings.warn(f'{p} exists. Removing...')
-        shutil.rmtree(p, ignore_errors=True)
-    p.mkdir(parents=False, exist_ok=False)
-
-
-def print_once(*msg):
-    if not dist.is_initialized() or dist.get_rank() == 0:
-        print(*msg)
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size*dilation - dilation)/2)
-
-
-class AttrDict(dict):
-    def __init__(self, *args, **kwargs):
-        super(AttrDict, self).__init__(*args, **kwargs)
-        self.__dict__ = self
-
-
-class DefaultAttrDict(defaultdict):
-    def __init__(self, *args, **kwargs):
-        super(DefaultAttrDict, self).__init__(*args, **kwargs)
-        self.__dict__ = self
-
-    def __getattr__(self, item):
-        return self[item]
-
-
-class BenchmarkStats:
-    """ Tracks statistics used for benchmarking. """
-    def __init__(self):
-        self.num_frames = []
-        self.losses = []
-        self.mel_losses = []
-        self.took = []
-
-    def update(self, num_frames, losses, mel_losses, took):
-        self.num_frames.append(num_frames)
-        self.losses.append(losses)
-        self.mel_losses.append(mel_losses)
-        self.took.append(took)
-
-    def get(self, n_epochs):
-        frames_s = sum(self.num_frames[-n_epochs:]) / sum(self.took[-n_epochs:])
-        return {'frames/s': frames_s,
-                'loss': np.mean(self.losses[-n_epochs:]),
-                'mel_loss': np.mean(self.mel_losses[-n_epochs:]),
-                'took': np.mean(self.took[-n_epochs:]),
-                'benchmark_epochs_num': n_epochs}
-
-    def __len__(self):
-        return len(self.losses)
-
-
-class Checkpointer:
-
-    def __init__(self, save_dir, keep_milestones=[]):
-        self.save_dir = save_dir
-        self.keep_milestones = keep_milestones
-
-        find = lambda name: [
-            (int(re.search("_(\d+).pt", fn).group(1)), fn)
-            for fn in glob.glob(f"{save_dir}/{name}_checkpoint_*.pt")]
-
-        tracked = sorted(find("FastPitch"), key=lambda t: t[0])
-        self.tracked = OrderedDict(tracked)
-
-    def last_checkpoint(self, output):
-
-        def corrupted(fpath):
-            try:
-                torch.load(fpath, map_location="cpu")
-                return False
-            except:
-                warnings.warn(f"Cannot load {fpath}")
-                return True
-
-        saved = sorted(
-            glob.glob(f"{output}/FastPitch_checkpoint_*.pt"),
-            key=lambda f: int(re.search("_(\d+).pt", f).group(1)))
-
-        if len(saved) >= 1 and not corrupted(saved[-1]):
-            return saved[-1]
-        elif len(saved) >= 2:
-            return saved[-2]
-        else:
-            return None
-
-    def maybe_load(self, model, optimizer, scaler, train_state, args,
-                   ema_model=None):
-
-        assert args.checkpoint_path is None or args.resume is False, (
-            "Specify a single checkpoint source")
-
-        fpath = None
-        if args.checkpoint_path is not None:
-            fpath = args.checkpoint_path
-            self.tracked = OrderedDict()  # Do not track/delete prev ckpts
-        elif args.resume:
-            fpath = self.last_checkpoint(args.output)
-
-        if fpath is None:
-            return
-
-        print_once(f"Loading model and optimizer state from {fpath}")
-        ckpt = torch.load(fpath, map_location="cpu")
-        train_state["epoch"] = ckpt["epoch"] + 1
-        train_state["total_iter"] = ckpt["iteration"]
-
-        no_pref = lambda sd: {re.sub("^module.", "", k): v for k, v in sd.items()}
-        unwrap = lambda m: getattr(m, "module", m)
-
-        unwrap(model).load_state_dict(no_pref(ckpt["state_dict"]))
-
-        if ema_model is not None:
-            unwrap(ema_model).load_state_dict(no_pref(ckpt["ema_state_dict"]))
-
-        optimizer.load_state_dict(ckpt["optimizer"])
-
-        if "scaler" in ckpt:
-            scaler.load_state_dict(ckpt["scaler"])
-        else:
-            warnings.warn("AMP scaler state missing from the checkpoint.")
-
-    def maybe_save(self, args, model, ema_model, optimizer, scaler, epoch,
-                   total_iter, config):
-
-        intermediate = (args.epochs_per_checkpoint > 0
-                        and epoch % args.epochs_per_checkpoint == 0)
-        final = epoch == args.epochs
-
-        if not intermediate and not final and epoch not in self.keep_milestones:
-            return
-
-        rank = 0
-        if dist.is_initialized():
-            dist.barrier()
-            rank = dist.get_rank()
-
-        if rank != 0:
-            return
-
-        unwrap = lambda m: getattr(m, "module", m)
-        ckpt = {"epoch": epoch,
-                "iteration": total_iter,
-                "config": config,
-                "train_setup": args.__dict__,
-                "state_dict": unwrap(model).state_dict(),
-                "optimizer": optimizer.state_dict(),
-                "scaler": scaler.state_dict()}
-        if ema_model is not None:
-            ckpt["ema_state_dict"] = unwrap(ema_model).state_dict()
-
-        fpath = Path(args.output, f"FastPitch_checkpoint_{epoch}.pt")
-        print(f"Saving model and optimizer state at epoch {epoch} to {fpath}")
-        torch.save(ckpt, fpath)
-
-        # Remove old checkpoints; keep milestones and the last two
-        self.tracked[epoch] = fpath
-        for epoch in set(list(self.tracked)[:-2]) - set(self.keep_milestones):
-            try:
-                os.remove(self.tracked[epoch])
-            except:
-                pass
-            del self.tracked[epoch]

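Among the helpers removed with utils_ok.py, mask_from_lens is the one the batching code leans on: it turns a vector of sequence lengths into a boolean padding mask. A small standalone demo, using the same logic as the deleted function:

    import torch

    def mask_from_lens(lens, max_len=None):
        # True where position index < sequence length, False in the padding.
        if max_len is None:
            max_len = lens.max()
        ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
        return torch.lt(ids, lens.unsqueeze(1))

    print(mask_from_lens(torch.tensor([3, 1, 2])))
    # tensor([[ True,  True,  True],
    #         [ True, False, False],
    #         [ True,  True, False]])
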
fastpitch/data_function (copy).py.txt
DELETED
@@ -1,425 +0,0 @@
-# *****************************************************************************
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-#     * Redistributions of source code must retain the above copyright
-#       notice, this list of conditions and the following disclaimer.
-#     * Redistributions in binary form must reproduce the above copyright
-#       notice, this list of conditions and the following disclaimer in the
-#       documentation and/or other materials provided with the distribution.
-#     * Neither the name of the NVIDIA CORPORATION nor the
-#       names of its contributors may be used to endorse or promote products
-#       derived from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
-# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#
-# *****************************************************************************
-
-import functools
-import json
-import re
-from pathlib import Path
-
-import librosa
-import numpy as np
-import torch
-import torch.nn.functional as F
-from scipy import ndimage
-from scipy.stats import betabinom
-
-import common.layers as layers
-from common.text.text_processing import TextProcessing
-from common.utils import load_wav_to_torch, load_filepaths_and_text, to_gpu
-
-
-class BetaBinomialInterpolator:
-    """Interpolates alignment prior matrices to save computation.
-
-    Calculating beta-binomial priors is costly. Instead cache popular sizes
-    and use img interpolation to get priors faster.
-    """
-    def __init__(self, round_mel_len_to=100, round_text_len_to=20):
-        self.round_mel_len_to = round_mel_len_to
-        self.round_text_len_to = round_text_len_to
-        self.bank = functools.lru_cache(beta_binomial_prior_distribution)
-
-    def round(self, val, to):
-        return max(1, int(np.round((val + 1) / to))) * to
-
-    def __call__(self, w, h):
-        bw = self.round(w, to=self.round_mel_len_to)
-        bh = self.round(h, to=self.round_text_len_to)
-        ret = ndimage.zoom(self.bank(bw, bh).T, zoom=(w / bw, h / bh), order=1)
-        assert ret.shape[0] == w, ret.shape
-        assert ret.shape[1] == h, ret.shape
-        return ret
-
-
-def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling=1.0):
-    P = phoneme_count
-    M = mel_count
-    x = np.arange(0, P)
-    mel_text_probs = []
-    for i in range(1, M+1):
-        a, b = scaling * i, scaling * (M + 1 - i)
-        rv = betabinom(P, a, b)
-        mel_i_prob = rv.pmf(x)
-        mel_text_probs.append(mel_i_prob)
-    return torch.tensor(np.array(mel_text_probs))
-
-
-def estimate_pitch(wav, mel_len, method='pyin', normalize_mean=None,
-                   normalize_std=None, n_formants=1):
-
-    if type(normalize_mean) is float or type(normalize_mean) is list:
-        normalize_mean = torch.tensor(normalize_mean)
-
-    if type(normalize_std) is float or type(normalize_std) is list:
-        normalize_std = torch.tensor(normalize_std)
-
-    if method == 'pyin':
-
-        snd, sr = librosa.load(wav)
-        pitch_mel, voiced_flag, voiced_probs = librosa.pyin(
-            snd, fmin=librosa.note_to_hz('C2'),
-            # fmax=librosa.note_to_hz('C7'), frame_length=1024)
-            fmax=400, frame_length=1024)
-        assert np.abs(mel_len - pitch_mel.shape[0]) <= 1.0
-
-        pitch_mel = np.where(np.isnan(pitch_mel), 0.0, pitch_mel)
-        pitch_mel = torch.from_numpy(pitch_mel).unsqueeze(0)
-        pitch_mel = F.pad(pitch_mel, (0, mel_len - pitch_mel.size(1)))
-
-        if n_formants > 1:
-            raise NotImplementedError
-
-    else:
-        raise ValueError
-
-    pitch_mel = pitch_mel.float()
-
-    if normalize_mean is not None:
-        assert normalize_std is not None
-        pitch_mel = normalize_pitch(pitch_mel, normalize_mean, normalize_std)
-
-    return pitch_mel
-
-
-def normalize_pitch(pitch, mean, std):
-    zeros = (pitch == 0.0)
-    pitch -= mean[:, None]
-    pitch /= std[:, None]
-    pitch[zeros] = 0.0
-    return pitch
-
-
-class TTSDataset(torch.utils.data.Dataset):
-    """
-        1) loads audio,text pairs
-        2) normalizes text and converts them to sequences of one-hot vectors
-        3) computes mel-spectrograms from audio files.
-    """
-    def __init__(self,
-                 dataset_path,
-                 audiopaths_and_text,
-                 text_cleaners,
-                 n_mel_channels,
-                 symbol_set='english_basic',
-                 p_arpabet=1.0,
-                 n_speakers=1,
-                 load_mel_from_disk=True,
-                 load_pitch_from_disk=True,
-                 pitch_mean=214.72203,  # LJSpeech defaults
-                 pitch_std=65.72038,
-                 max_wav_value=None,
-                 sampling_rate=None,
-                 filter_length=None,
-                 hop_length=None,
-                 win_length=None,
-                 mel_fmin=None,
-                 mel_fmax=None,
-                 prepend_space_to_text=False,
-                 append_space_to_text=False,
-                 pitch_online_dir=None,
-                 betabinomial_online_dir=None,
-                 use_betabinomial_interpolator=True,
-                 pitch_online_method='pyin',
-                 **ignored):
-
-        # Expect a list of filenames
-        if type(audiopaths_and_text) is str:
-            audiopaths_and_text = [audiopaths_and_text]
-
-        self.dataset_path = dataset_path
-        self.audiopaths_and_text = load_filepaths_and_text(
-            dataset_path, audiopaths_and_text,
-            has_speakers=(n_speakers > 1))
-        self.load_mel_from_disk = load_mel_from_disk
-        if not load_mel_from_disk:
-            self.max_wav_value = max_wav_value
-            self.sampling_rate = sampling_rate
-            self.stft = layers.TacotronSTFT(
-                filter_length, hop_length, win_length,
-                n_mel_channels, sampling_rate, mel_fmin, mel_fmax)
-        self.load_pitch_from_disk = load_pitch_from_disk
-
-        self.prepend_space_to_text = prepend_space_to_text
-        self.append_space_to_text = append_space_to_text
-
-        assert p_arpabet == 0.0 or p_arpabet == 1.0, (
-            'Only 0.0 and 1.0 p_arpabet is currently supported. '
-            'Variable probability breaks caching of betabinomial matrices.')
-
-        self.tp = TextProcessing(symbol_set, text_cleaners, p_arpabet=p_arpabet)
-        self.n_speakers = n_speakers
-        self.pitch_tmp_dir = pitch_online_dir
-        self.f0_method = pitch_online_method
-        self.betabinomial_tmp_dir = betabinomial_online_dir
-        self.use_betabinomial_interpolator = use_betabinomial_interpolator
-
-        if use_betabinomial_interpolator:
-            self.betabinomial_interpolator = BetaBinomialInterpolator()
-
-        expected_columns = (2 + int(load_pitch_from_disk) + (n_speakers > 1))
-
-        assert not (load_pitch_from_disk and self.pitch_tmp_dir is not None)
-
-        if len(self.audiopaths_and_text[0]) < expected_columns:
-            raise ValueError(f'Expected {expected_columns} columns in audiopaths file. '
-                             'The format is <mel_or_wav>|[<pitch>|]<text>[|<speaker_id>]')
-
-        if len(self.audiopaths_and_text[0]) > expected_columns:
-            print('WARNING: Audiopaths file has more columns than expected')
-
-        to_tensor = lambda x: torch.Tensor([x]) if type(x) is float else x
-        self.pitch_mean = to_tensor(pitch_mean)
-        self.pitch_std = to_tensor(pitch_std)
-
-    def __getitem__(self, index):
-        # Separate filename and text
-        if self.n_speakers > 1:
-            audiopath, *extra, text, speaker = self.audiopaths_and_text[index]
-            speaker = int(speaker)
-        else:
-            audiopath, *extra, text = self.audiopaths_and_text[index]
-            speaker = None
-
-        mel = self.get_mel(audiopath)
-        text = self.get_text(text)
-        # print(text)
-        pitch = self.get_pitch(index, mel.size(-1))
-        energy = torch.norm(mel.float(), dim=0, p=2)
-        attn_prior = self.get_prior(index, mel.shape[1], text.shape[0])
-
-        assert pitch.size(-1) == mel.size(-1)
-
-        # No higher formants?
-        if len(pitch.size()) == 1:
-            pitch = pitch[None, :]
-
-
-        return (text, mel, len(text), pitch, energy, speaker, attn_prior,
-                audiopath)
-
-    def __len__(self):
-        return len(self.audiopaths_and_text)
-
-    def get_mel(self, filename):
-        if not self.load_mel_from_disk:
-            audio, sampling_rate = load_wav_to_torch(filename)
-            if sampling_rate != self.stft.sampling_rate:
-                raise ValueError("{} SR doesn't match target {} SR".format(
-                    sampling_rate, self.stft.sampling_rate))
-            audio_norm = audio / self.max_wav_value
-            audio_norm = audio_norm.unsqueeze(0)
-            audio_norm = torch.autograd.Variable(audio_norm,
-                                                 requires_grad=False)
-            melspec = self.stft.mel_spectrogram(audio_norm)
-            melspec = torch.squeeze(melspec, 0)
-        else:
-            melspec = torch.load(filename)
-            assert melspec.size(0) == self.stft.n_mel_channels, (
-                'Mel dimension mismatch: given {}, expected {}'.format(
-                    melspec.size(0), self.stft.n_mel_channels))
-
-        ################ Plotting mels ########################################
-        import matplotlib.pyplot as plt
-        # plt.imshow(melspec.detach().cpu().T,aspect="auto")
-        fig, ax1 = plt.subplots(ncols=1)
-        pos = ax1.imshow(melspec.cpu().numpy().T,aspect="auto")
-        fig.colorbar(pos, ax=ax1)
-        plt.show()
-        #######################################################################
-
-        return melspec
-
-    def get_text(self, text):
-        text = self.tp.encode_text(text)
-        space = [self.tp.encode_text("A A")[1]]
-
-        if self.prepend_space_to_text:
-            text = space + text
-
-        if self.append_space_to_text:
-            text = text + space
-
-        return torch.LongTensor(text)
-
-    def get_prior(self, index, mel_len, text_len):
-
-        if self.use_betabinomial_interpolator:
-            return torch.from_numpy(self.betabinomial_interpolator(mel_len,
-                                                                   text_len))
-
-        if self.betabinomial_tmp_dir is not None:
-            audiopath, *_ = self.audiopaths_and_text[index]
-            fname = Path(audiopath).relative_to(self.dataset_path)
-            fname = fname.with_suffix('.pt')
-            cached_fpath = Path(self.betabinomial_tmp_dir, fname)
-
-            if cached_fpath.is_file():
-                return torch.load(cached_fpath)
-
-        attn_prior = beta_binomial_prior_distribution(text_len, mel_len)
-
-        if self.betabinomial_tmp_dir is not None:
-            cached_fpath.parent.mkdir(parents=True, exist_ok=True)
-            torch.save(attn_prior, cached_fpath)
-
-        return attn_prior
-
-    def get_pitch(self, index, mel_len=None):
-        audiopath, *fields = self.audiopaths_and_text[index]
-
-        if self.n_speakers > 1:
-            spk = int(fields[-1])
-        else:
-            spk = 0
-
-        if self.load_pitch_from_disk:
-            pitchpath = fields[0]
-            pitch = torch.load(pitchpath)
-            if self.pitch_mean is not None:
-                assert self.pitch_std is not None
-                pitch = normalize_pitch(pitch, self.pitch_mean, self.pitch_std)
-            return pitch
-
-        if self.pitch_tmp_dir is not None:
-            fname = Path(audiopath).relative_to(self.dataset_path)
-            fname_method = fname.with_suffix('.pt')
-            cached_fpath = Path(self.pitch_tmp_dir, fname_method)
-            if cached_fpath.is_file():
-                return torch.load(cached_fpath)
-
-        # No luck so far - calculate
-        wav = audiopath
-        if not wav.endswith('.wav'):
-            wav = re.sub('/mels/', '/wavs/', wav)
-            wav = re.sub('.pt$', '.wav', wav)
-
-        pitch_mel = estimate_pitch(wav, mel_len, self.f0_method,
-                                   self.pitch_mean, self.pitch_std)
-
-        if self.pitch_tmp_dir is not None and not cached_fpath.is_file():
-            cached_fpath.parent.mkdir(parents=True, exist_ok=True)
-            torch.save(pitch_mel, cached_fpath)
-
-        return pitch_mel
-
-
-class TTSCollate:
-    """Zero-pads model inputs and targets based on number of frames per step"""
-
-    def __call__(self, batch):
-        """Collate training batch from normalized text and mel-spec"""
-        # Right zero-pad all one-hot text sequences to max input length
-        input_lengths, ids_sorted_decreasing = torch.sort(
-            torch.LongTensor([len(x[0]) for x in batch]),
-            dim=0, descending=True)
-        max_input_len = input_lengths[0]
-
-        text_padded = torch.LongTensor(len(batch), max_input_len)
-        text_padded.zero_()
-        for i in range(len(ids_sorted_decreasing)):
-            text = batch[ids_sorted_decreasing[i]][0]
-            text_padded[i, :text.size(0)] = text
-
-        # Right zero-pad mel-spec
-        num_mels = batch[0][1].size(0)
-        max_target_len = max([x[1].size(1) for x in batch])
-
-        # Include mel padded and gate padded
-        mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
-        mel_padded.zero_()
-        output_lengths = torch.LongTensor(len(batch))
-        for i in range(len(ids_sorted_decreasing)):
-            mel = batch[ids_sorted_decreasing[i]][1]
-            mel_padded[i, :, :mel.size(1)] = mel
-            output_lengths[i] = mel.size(1)
-
-        n_formants = batch[0][3].shape[0]
-        pitch_padded = torch.zeros(mel_padded.size(0), n_formants,
-                                   mel_padded.size(2), dtype=batch[0][3].dtype)
-        energy_padded = torch.zeros_like(pitch_padded[:, 0, :])
-
-        for i in range(len(ids_sorted_decreasing)):
-            pitch = batch[ids_sorted_decreasing[i]][3]
-            energy = batch[ids_sorted_decreasing[i]][4]
-            pitch_padded[i, :, :pitch.shape[1]] = pitch
-            energy_padded[i, :energy.shape[0]] = energy
-
-        if batch[0][5] is not None:
-            speaker = torch.zeros_like(input_lengths)
-            for i in range(len(ids_sorted_decreasing)):
-                speaker[i] = batch[ids_sorted_decreasing[i]][5]
-        else:
-            speaker = None
-
-        attn_prior_padded = torch.zeros(len(batch), max_target_len,
-                                        max_input_len)
-        attn_prior_padded.zero_()
-        for i in range(len(ids_sorted_decreasing)):
-            prior = batch[ids_sorted_decreasing[i]][6]
-            attn_prior_padded[i, :prior.size(0), :prior.size(1)] = prior
-
-        # Count number of items - characters in text
-        len_x = [x[2] for x in batch]
-        len_x = torch.Tensor(len_x)
-
-        audiopaths = [batch[i][7] for i in ids_sorted_decreasing]
-
-        return (text_padded, input_lengths, mel_padded, output_lengths, len_x,
-                pitch_padded, energy_padded, speaker, attn_prior_padded,
-                audiopaths)
-
-
-def batch_to_gpu(batch):
-    (text_padded, input_lengths, mel_padded, output_lengths, len_x,
-     pitch_padded, energy_padded, speaker, attn_prior, audiopaths) = batch
-
-    text_padded = to_gpu(text_padded).long()
-    input_lengths = to_gpu(input_lengths).long()
-    mel_padded = to_gpu(mel_padded).float()
-    output_lengths = to_gpu(output_lengths).long()
-    pitch_padded = to_gpu(pitch_padded).float()
-    energy_padded = to_gpu(energy_padded).float()
-    attn_prior = to_gpu(attn_prior).float()
-    if speaker is not None:
-        speaker = to_gpu(speaker).long()
-
-    # Alignments act as both inputs and targets - pass shallow copies
-    x = [text_padded, input_lengths, mel_padded, output_lengths,
-         pitch_padded, energy_padded, speaker, attn_prior, audiopaths]
-    y = [mel_padded, input_lengths, output_lengths]
-    len_x = torch.sum(output_lengths)
-    return (x, y, len_x)

fastpitch/data_function_model_py.zip
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7977160d12775529ba3426181093c8bca7927e52024f6d4faa91e1a0e53ef008
size 9564
fastpitch/utils_trainplot_transformers.zip
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5d3b72d713798a0552c0939fded67cb89655e12cc406a6fe793fcc7c9f63456a
size 16365
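Note: the two deleted .zip entries are Git LFS pointer files rather than the archives themselves; the three key/value lines (version, oid, size) are the entire pointer format. A small sketch of parsing such a pointer (hypothetical helper, not part of this repo):

def parse_lfs_pointer(text):
    # Splits each "key value" line of the Git LFS pointer format shown above,
    # e.g. {'version': ..., 'oid': 'sha256:...', 'size': '9564'}.
    fields = dict(line.split(' ', 1) for line in text.strip().splitlines())
    assert fields['version'].startswith('https://git-lfs.github.com/spec/')
    return fields

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:7977160d12775529ba3426181093c8bca7927e52024f6d4faa91e1a0e53ef008
size 9564"""
assert parse_lfs_pointer(pointer)['size'] == '9564'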
fastpitch/utils_trainplot_transformers/train_1_with_plot.py
DELETED
@@ -1,591 +0,0 @@
# *****************************************************************************
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the
#       names of its contributors may be used to endorse or promote products
#       derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************

import argparse
import copy
import os
import time
from collections import defaultdict, OrderedDict
from itertools import cycle

import numpy as np
import torch
import torch.distributed as dist
import amp_C
from apex.optimizers import FusedAdam, FusedLAMB
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

import common.tb_dllogger as logger
import models
from common.tb_dllogger import log
from common.repeated_dataloader import (RepeatedDataLoader,
                                        RepeatedDistributedSampler)
from common.text import cmudict
from common.utils import BenchmarkStats, Checkpointer, prepare_tmp
from fastpitch.attn_loss_function import AttentionBinarizationLoss
from fastpitch.data_function import batch_to_gpu, TTSCollate, TTSDataset
from fastpitch.loss_function import FastPitchLoss

import matplotlib.pyplot as plt


def parse_args(parser):
    parser.add_argument('-o', '--output', type=str, required=True,
                        help='Directory to save checkpoints')
    parser.add_argument('-d', '--dataset-path', type=str, default='./',
                        help='Path to dataset')
    parser.add_argument('--log-file', type=str, default=None,
                        help='Path to a DLLogger log file')

    train = parser.add_argument_group('training setup')
    train.add_argument('--epochs', type=int, required=True,
                       help='Number of total epochs to run')
    train.add_argument('--epochs-per-checkpoint', type=int, default=50,
                       help='Number of epochs per checkpoint')
    train.add_argument('--checkpoint-path', type=str, default=None,
                       help='Checkpoint path to resume training')
    train.add_argument('--keep-milestones', default=list(range(100, 1000, 100)),
                       type=int, nargs='+',
                       help='Milestone checkpoints to keep from removing')
    train.add_argument('--resume', action='store_true',
                       help='Resume training from the last checkpoint')
    train.add_argument('--seed', type=int, default=1234,
                       help='Seed for PyTorch random number generators')
    train.add_argument('--amp', action='store_true',
                       help='Enable AMP')
    train.add_argument('--cuda', action='store_true',
                       help='Run on GPU using CUDA')
    train.add_argument('--cudnn-benchmark', action='store_true',
                       help='Enable cudnn benchmark mode')
    train.add_argument('--ema-decay', type=float, default=0,
                       help='Discounting factor for training weights EMA')
    train.add_argument('--grad-accumulation', type=int, default=1,
                       help='Training steps to accumulate gradients for')
    train.add_argument('--kl-loss-start-epoch', type=int, default=250,
                       help='Start adding the hard attention loss term')
    train.add_argument('--kl-loss-warmup-epochs', type=int, default=100,
                       help='Gradually increase the hard attention loss term')
    train.add_argument('--kl-loss-weight', type=float, default=1.0,
                       help='Weight of the hard attention loss term')
    train.add_argument('--benchmark-epochs-num', type=int, default=20,
                       help='Number of epochs for calculating final stats')
    train.add_argument('--validation-freq', type=int, default=1,
                       help='Validate every N epochs to use less compute')

    opt = parser.add_argument_group('optimization setup')
    opt.add_argument('--optimizer', type=str, default='lamb',
                     help='Optimization algorithm')
    opt.add_argument('-lr', '--learning-rate', type=float, required=True,
                     help='Learning rate')
    opt.add_argument('--weight-decay', default=1e-6, type=float,
                     help='Weight decay')
    opt.add_argument('--grad-clip-thresh', default=1000.0, type=float,
                     help='Clip threshold for gradients')
    opt.add_argument('-bs', '--batch-size', type=int, required=True,
                     help='Batch size per GPU')
    opt.add_argument('--warmup-steps', type=int, default=1000,
                     help='Number of steps for lr warmup')
    opt.add_argument('--dur-predictor-loss-scale', type=float,
                     default=1.0, help='Rescale duration predictor loss')
    opt.add_argument('--pitch-predictor-loss-scale', type=float,
                     default=1.0, help='Rescale pitch predictor loss')
    opt.add_argument('--attn-loss-scale', type=float,
                     default=1.0, help='Rescale alignment loss')

    data = parser.add_argument_group('dataset parameters')
    data.add_argument('--training-files', type=str, nargs='*', required=True,
                      help='Paths to training filelists.')
    data.add_argument('--validation-files', type=str, nargs='*',
                      required=True, help='Paths to validation filelists')
    data.add_argument('--text-cleaners', nargs='*',
                      default=['english_cleaners'], type=str,
                      help='Type of text cleaners for input text')
    data.add_argument('--symbol-set', type=str, default='english_basic',
                      help='Define symbol set for input text')
    data.add_argument('--p-arpabet', type=float, default=0.0,
                      help='Probability of using arpabets instead of graphemes '
                           'for each word; set 0 for pure grapheme training')
    data.add_argument('--heteronyms-path', type=str, default='cmudict/heteronyms',
                      help='Path to the list of heteronyms')
    data.add_argument('--cmudict-path', type=str, default='cmudict/cmudict-0.7b',
                      help='Path to the pronouncing dictionary')
    data.add_argument('--prepend-space-to-text', action='store_true',
                      help='Capture leading silence with a space token')
    data.add_argument('--append-space-to-text', action='store_true',
                      help='Capture trailing silence with a space token')
    data.add_argument('--num-workers', type=int, default=2,  # 6
                      help='Subprocesses for train and val DataLoaders')
    data.add_argument('--trainloader-repeats', type=int, default=100,
                      help='Repeats the dataset to prolong epochs')

    cond = parser.add_argument_group('data for conditioning')
    cond.add_argument('--n-speakers', type=int, default=1,
                      help='Number of speakers in the dataset. '
                           'n_speakers > 1 enables speaker embeddings')
    cond.add_argument('--load-pitch-from-disk', action='store_true',
                      help='Use pitch cached on disk with prepare_dataset.py')
    cond.add_argument('--pitch-online-method', default='pyin',
                      choices=['pyin'],
                      help='Calculate pitch on the fly during training')
    cond.add_argument('--pitch-online-dir', type=str, default=None,
                      help='A directory for storing pitch calculated on-line')
    cond.add_argument('--pitch-mean', type=float, default=125.626816,  # default=214.72203,
                      help='Normalization value for pitch')
    cond.add_argument('--pitch-std', type=float, default=37.52,  # default=65.72038,
                      help='Normalization value for pitch')
    cond.add_argument('--load-mel-from-disk', action='store_true',
                      help='Use mel-spectrograms cache on the disk')  # XXX

    audio = parser.add_argument_group('audio parameters')
    audio.add_argument('--max-wav-value', default=32768.0, type=float,
                       help='Maximum audiowave value')
    audio.add_argument('--sampling-rate', default=22050, type=int,
                       help='Sampling rate')
    audio.add_argument('--filter-length', default=1024, type=int,
                       help='Filter length')
    audio.add_argument('--hop-length', default=256, type=int,
                       help='Hop (stride) length')
    audio.add_argument('--win-length', default=1024, type=int,
                       help='Window length')
    audio.add_argument('--mel-fmin', default=0.0, type=float,
                       help='Minimum mel frequency')
    audio.add_argument('--mel-fmax', default=8000.0, type=float,
                       help='Maximum mel frequency')

    dist = parser.add_argument_group('distributed setup')
    dist.add_argument('--local_rank', type=int, default=os.getenv('LOCAL_RANK', 0),
                      help='Rank of the process for multiproc; do not set manually')
    dist.add_argument('--world_size', type=int, default=os.getenv('WORLD_SIZE', 1),
                      help='Number of processes for multiproc; do not set manually')
    return parser


def reduce_tensor(tensor, num_gpus):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt.true_divide(num_gpus)


def init_distributed(args, world_size, rank):
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
    print("Initializing distributed training")

    # Set cuda device so everything is done on the right GPU.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Initialize distributed communication
    dist.init_process_group(backend=('nccl' if args.cuda else 'gloo'),
                            init_method='env://')
    print("Done initializing distributed training")


def validate(model, epoch, total_iter, criterion, val_loader, distributed_run,
             batch_to_gpu, local_rank, ema=False):
    was_training = model.training
    model.eval()

    tik = time.perf_counter()
    with torch.no_grad():
        val_meta = defaultdict(float)
        val_num_frames = 0
        for i, batch in enumerate(val_loader):
            x, y, num_frames = batch_to_gpu(batch)
            y_pred = model(x)
            loss, meta = criterion(y_pred, y, is_training=False, meta_agg='sum')

            if distributed_run:
                for k, v in meta.items():
                    val_meta[k] += reduce_tensor(v, 1)
                val_num_frames += reduce_tensor(num_frames.data, 1).item()
            else:
                for k, v in meta.items():
                    val_meta[k] += v
                val_num_frames += num_frames.item()

            # NOTE: ugly patch to visualize the first utterance of the validation corpus.
            # The goal is to determine if the training is progressing properly
            if (i == 0) and (local_rank == 0) and (not ema):
                # Plot some debug information
                fig, axs = plt.subplots(2, 2, figsize=(21, 14))

                # - Mel-spectrogram
                pred_mel = y_pred[0][0, :, :].cpu().detach().numpy().astype(np.float32).T
                orig_mel = y[0][0, :, :].cpu().detach().numpy().astype(np.float32)
                axs[0, 0].imshow(orig_mel, aspect='auto', origin='lower', interpolation='nearest')
                axs[1, 0].imshow(pred_mel, aspect='auto', origin='lower', interpolation='nearest')

                # Prosody
                f0_pred = y_pred[4][0, :].cpu().detach().numpy().astype(np.float32)
                f0_ori = y_pred[5][0, :].cpu().detach().numpy().astype(np.float32)
                axs[1, 1].plot(f0_ori)
                axs[1, 1].plot(f0_pred)

                # # Duration
                # att_pred = y_pred[2][0, :].cpu().detach().numpy().astype(np.float32)
                # att_ori = x[7][0, :].cpu().detach().numpy().astype(np.float32)
                # axs[0, 1].imshow(att_ori, aspect='auto', origin='lower', interpolation='nearest')

                if not os.path.exists("debug_epoch/"):
                    os.makedirs("debug_epoch/")

                fig.savefig(f'debug_epoch/{epoch:06d}.png', bbox_inches='tight')

        val_meta = {k: v / len(val_loader.dataset) for k, v in val_meta.items()}

    val_meta['took'] = time.perf_counter() - tik

    log((epoch,) if epoch is not None else (), tb_total_steps=total_iter,
        subset='val_ema' if ema else 'val',
        data=OrderedDict([
            ('loss', val_meta['loss'].item()),
            ('mel_loss', val_meta['mel_loss'].item()),
            ('frames/s', val_num_frames / val_meta['took']),
            ('took', val_meta['took'])]),
        )

    if was_training:
        model.train()
    return val_meta


def adjust_learning_rate(total_iter, opt, learning_rate, warmup_iters=None):
    if warmup_iters == 0:
        scale = 1.0
    elif total_iter > warmup_iters:
        scale = 1. / (total_iter ** 0.5)
    else:
        scale = total_iter / (warmup_iters ** 1.5)

    for param_group in opt.param_groups:
        param_group['lr'] = learning_rate * scale


def apply_ema_decay(model, ema_model, decay):
    if not decay:
        return
    st = model.state_dict()
    add_module = hasattr(model, 'module') and not hasattr(ema_model, 'module')
    for k, v in ema_model.state_dict().items():
        if add_module and not k.startswith('module.'):
            k = 'module.' + k
        v.copy_(decay * v + (1 - decay) * st[k])


def init_multi_tensor_ema(model, ema_model):
    model_weights = list(model.state_dict().values())
    ema_model_weights = list(ema_model.state_dict().values())
    ema_overflow_buf = torch.cuda.IntTensor([0])
    return model_weights, ema_model_weights, ema_overflow_buf


def apply_multi_tensor_ema(decay, model_weights, ema_weights, overflow_buf):
    amp_C.multi_tensor_axpby(
        65536, overflow_buf, [ema_weights, model_weights, ema_weights],
        decay, 1 - decay, -1)


def main():
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    if args.p_arpabet > 0.0:
        cmudict.initialize(args.cmudict_path, args.heteronyms_path)

    distributed_run = args.world_size > 1

    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)

    if args.local_rank == 0:
        if not os.path.exists(args.output):
            os.makedirs(args.output)

    log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
    tb_subsets = ['train', 'val']
    if args.ema_decay > 0.0:
        tb_subsets.append('val_ema')

    logger.init(log_fpath, args.output, enabled=(args.local_rank == 0),
                tb_subsets=tb_subsets)
    logger.parameters(vars(args), tb_subset='train')

    parser = models.parse_model_args('FastPitch', parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, args.world_size, args.local_rank)
    else:
        if args.trainloader_repeats > 1:
            print('WARNING: Disabled --trainloader-repeats, supported only for'
                  ' multi-GPU data loading.')
            args.trainloader_repeats = 1

    device = torch.device('cuda' if args.cuda else 'cpu')
    model_config = models.get_model_config('FastPitch', args)
    model = models.get_model('FastPitch', model_config, device)

    attention_kl_loss = AttentionBinarizationLoss()

    # Store pitch mean/std as params to translate from Hz during inference
    model.pitch_mean[0] = args.pitch_mean
    model.pitch_std[0] = args.pitch_std

    kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
        # optimizer = torch.optim.Adam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
        # optimizer = torch.optim.Adam(model.parameters(), **kw)
    else:
        raise ValueError

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    if args.ema_decay > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if distributed_run:
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank,
            find_unused_parameters=True)

    train_state = {'epoch': 1, 'total_iter': 1}
    checkpointer = Checkpointer(args.output, args.keep_milestones)

    checkpointer.maybe_load(model, optimizer, scaler, train_state, args,
                            ema_model)

    start_epoch = train_state['epoch']
    total_iter = train_state['total_iter']

    criterion = FastPitchLoss(
        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale,
        attn_loss_scale=args.attn_loss_scale)

    collate_fn = TTSCollate()

    if args.local_rank == 0:
        prepare_tmp(args.pitch_online_dir)

    trainset = TTSDataset(audiopaths_and_text=args.training_files, **vars(args))
    valset = TTSDataset(audiopaths_and_text=args.validation_files, **vars(args))

    if distributed_run:
        train_sampler = RepeatedDistributedSampler(args.trainloader_repeats,
                                                   trainset, drop_last=True)
        val_sampler = DistributedSampler(valset)
        shuffle = False
    else:
        train_sampler, val_sampler, shuffle = None, None, False  # was True

    # 4 workers are optimal on DGX-1 (from epoch 2 onwards)
    kw = {'num_workers': args.num_workers, 'batch_size': args.batch_size,
          'collate_fn': collate_fn}
    train_loader = RepeatedDataLoader(args.trainloader_repeats, trainset,
                                      shuffle=shuffle, drop_last=True,
                                      sampler=train_sampler, pin_memory=True,
                                      persistent_workers=True, **kw)
    val_loader = DataLoader(valset, shuffle=False, sampler=val_sampler,
                            pin_memory=False, **kw)
    if args.ema_decay:
        mt_ema_params = init_multi_tensor_ema(model, ema_model)

    model.train()
    bmark_stats = BenchmarkStats()

    torch.cuda.synchronize()
    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start_time = time.perf_counter()

        epoch_loss = 0.0
        epoch_mel_loss = 0.0
        epoch_num_frames = 0
        epoch_frames_per_sec = 0.0

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        iter_loss = 0
        iter_num_frames = 0
        iter_meta = {}
        iter_start_time = time.perf_counter()

        epoch_iter = 1
        for batch, accum_step in zip(train_loader,
                                     cycle(range(1, args.grad_accumulation + 1))):
            if accum_step == 1:
                adjust_learning_rate(total_iter, optimizer, args.learning_rate,
                                     args.warmup_steps)

                model.zero_grad(set_to_none=True)

            x, y, num_frames = batch_to_gpu(batch)

            with torch.cuda.amp.autocast(enabled=args.amp):
                y_pred = model(x)
                loss, meta = criterion(y_pred, y)

                if (args.kl_loss_start_epoch is not None
                        and epoch >= args.kl_loss_start_epoch):

                    if args.kl_loss_start_epoch == epoch and epoch_iter == 1:
                        print('Begin hard_attn loss')

                    _, _, _, _, _, _, _, _, attn_soft, attn_hard, _, _ = y_pred
                    binarization_loss = attention_kl_loss(attn_hard, attn_soft)
                    kl_weight = min((epoch - args.kl_loss_start_epoch) / args.kl_loss_warmup_epochs, 1.0) * args.kl_loss_weight
                    meta['kl_loss'] = binarization_loss.clone().detach() * kl_weight
                    loss += kl_weight * binarization_loss

                else:
                    meta['kl_loss'] = torch.zeros_like(loss)
                    kl_weight = 0
                    binarization_loss = 0

                loss /= args.grad_accumulation

            meta = {k: v / args.grad_accumulation
                    for k, v in meta.items()}

            if args.amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
                meta = {k: reduce_tensor(v, args.world_size) for k, v in meta.items()}
            else:
                reduced_loss = loss.item()
                reduced_num_frames = num_frames.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            iter_loss += reduced_loss
            iter_num_frames += reduced_num_frames
            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}

            if accum_step % args.grad_accumulation == 0:

                logger.log_grads_tb(total_iter, model)
                if args.amp:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.grad_clip_thresh)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.grad_clip_thresh)
                    optimizer.step()

                if args.ema_decay > 0.0:
                    apply_multi_tensor_ema(args.ema_decay, *mt_ema_params)

                iter_mel_loss = iter_meta['mel_loss'].item()
                iter_kl_loss = iter_meta['kl_loss'].item()
                iter_time = time.perf_counter() - iter_start_time
                epoch_frames_per_sec += iter_num_frames / iter_time
                epoch_loss += iter_loss
                epoch_num_frames += iter_num_frames
                epoch_mel_loss += iter_mel_loss

                num_iters = len(train_loader) // args.grad_accumulation
                log((epoch, epoch_iter, num_iters), tb_total_steps=total_iter,
                    subset='train', data=OrderedDict([
                        ('loss', iter_loss),
                        ('mel_loss', iter_mel_loss),
                        ('kl_loss', iter_kl_loss),
                        ('kl_weight', kl_weight),
                        ('frames/s', iter_num_frames / iter_time),
                        ('took', iter_time),
                        ('lrate', optimizer.param_groups[0]['lr'])]),
                    )

                iter_loss = 0
                iter_num_frames = 0
                iter_meta = {}
                iter_start_time = time.perf_counter()

                if epoch_iter == num_iters:
                    break
                epoch_iter += 1
                total_iter += 1

        # Finished epoch
        epoch_loss /= epoch_iter
        epoch_mel_loss /= epoch_iter
        epoch_time = time.perf_counter() - epoch_start_time

        log((epoch,), tb_total_steps=None, subset='train_avg',
            data=OrderedDict([
                ('loss', epoch_loss),
                ('mel_loss', epoch_mel_loss),
                ('frames/s', epoch_num_frames / epoch_time),
                ('took', epoch_time)]),
            )
        bmark_stats.update(epoch_num_frames, epoch_loss, epoch_mel_loss,
                           epoch_time)

        if epoch % args.validation_freq == 0:
            validate(model, epoch, total_iter, criterion, val_loader,
                     distributed_run, batch_to_gpu, ema=False, local_rank=args.local_rank)

        if args.ema_decay > 0:
            validate(ema_model, epoch, total_iter, criterion, val_loader,
                     distributed_run, batch_to_gpu, args.local_rank, ema=True)

        # save before making sched.step() for proper loading of LR
        checkpointer.maybe_save(args, model, ema_model, optimizer, scaler,
                                epoch, total_iter, model_config)
        logger.flush()

    # Finished training
    if len(bmark_stats) > 0:
        log((), tb_total_steps=None, subset='train_avg',
            data=bmark_stats.get(args.benchmark_epochs_num))

    validate(model, None, total_iter, criterion, val_loader, distributed_run,
             batch_to_gpu, args.local_rank)


if __name__ == '__main__':
    main()
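Note: the adjust_learning_rate helper in the deleted script implements a Noam-style schedule: a ramp over the first warmup_iters steps, then inverse square-root decay. A small self-contained sketch of the same scale factor, with hypothetical numbers just to show the shape:

def lr_scale(step, warmup_iters):
    # Same piecewise rule as adjust_learning_rate above:
    # ramp up as step / warmup_iters**1.5, then decay as 1 / sqrt(step).
    if warmup_iters == 0:
        return 1.0
    if step > warmup_iters:
        return 1.0 / (step ** 0.5)
    return step / (warmup_iters ** 1.5)

# The two branches meet exactly at step == warmup_iters:
w = 1000
assert abs(lr_scale(w, w) - 1.0 / (w ** 0.5)) < 1e-12
print([round(lr_scale(s, w), 5) for s in (1, 500, 1000, 4000, 16000)])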
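Note: apply_ema_decay and apply_multi_tensor_ema above maintain an exponential moving average of the weights; the amp_C.multi_tensor_axpby call is effectively a fused apex kernel for ema = decay * ema + (1 - decay) * model. A plain-PyTorch sketch of the same update, with no apex dependency and a hypothetical toy model standing in for FastPitch:

import copy
import torch

model = torch.nn.Linear(4, 4)   # hypothetical stand-in for the FastPitch model
ema_model = copy.deepcopy(model)

@torch.no_grad()
def ema_update(model, ema_model, decay=0.999):
    # ema <- decay * ema + (1 - decay) * model, tensor by tensor;
    # the same arithmetic the fused multi_tensor_axpby call performs.
    for p, ema_p in zip(model.state_dict().values(),
                        ema_model.state_dict().values()):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)

ema_update(model, ema_model)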
fastpitch/utils_trainplot_transformers/transformer.py
DELETED
@@ -1,213 +0,0 @@
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

from common.utils import mask_from_lens


class PositionalEmbedding(nn.Module):
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()
        self.demb = demb
        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq, bsz=None):
        sinusoid_inp = torch.matmul(torch.unsqueeze(pos_seq, -1),
                                    torch.unsqueeze(self.inv_freq, 0))
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
        if bsz is not None:
            return pos_emb[None, :, :].expand(bsz, -1, -1)
        else:
            return pos_emb[None, :, :]


class PositionwiseConvFF(nn.Module):
    def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False):
        super(PositionwiseConvFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet = nn.Sequential(
            nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)),
            nn.ReLU(),
            # nn.Dropout(dropout),  # worse convergence
            nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)),
            nn.Dropout(dropout),
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        return self._forward(inp)

    def _forward(self, inp):
        if self.pre_lnorm:
            # layer normalization + positionwise feed-forward
            core_out = inp.transpose(1, 2)
            core_out = self.CoreNet(self.layer_norm(core_out).to(inp.dtype))
            core_out = core_out.transpose(1, 2)

            # residual connection
            output = core_out + inp
        else:
            # positionwise feed-forward
            core_out = inp.transpose(1, 2)
            core_out = self.CoreNet(core_out)
            core_out = core_out.transpose(1, 2)

            # residual connection + layer normalization
            output = self.layer_norm(inp + core_out).to(inp.dtype)

        return output


class MultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1,
                 pre_lnorm=False):
        super(MultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.scale = 1 / (d_head ** 0.5)
        self.pre_lnorm = pre_lnorm

        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head)
        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inp, attn_mask=None):
        return self._forward(inp, attn_mask)

    def _forward(self, inp, attn_mask=None):
        residual = inp

        if self.pre_lnorm:
            # layer normalization
            inp = self.layer_norm(inp)

        n_head, d_head = self.n_head, self.d_head

        head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2)
        head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head)
        head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head)
        head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head)

        q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
        k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
        v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)

        attn_score = torch.bmm(q, k.transpose(1, 2))
        attn_score.mul_(self.scale)

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1).to(attn_score.dtype)
            attn_mask = attn_mask.repeat(n_head, attn_mask.size(2), 1)
            attn_score.masked_fill_(attn_mask.to(torch.bool), -float('inf'))

        attn_prob = F.softmax(attn_score, dim=2)
        attn_prob = self.dropatt(attn_prob)
        attn_vec = torch.bmm(attn_prob, v)

        attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head)
        attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(
            inp.size(0), inp.size(1), n_head * d_head)

        # linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            # residual connection
            output = residual + attn_out
        else:
            # residual connection + layer normalization
            output = self.layer_norm(residual + attn_out)

        output = output.to(attn_out.dtype)

        return output


class TransformerLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, kernel_size, dropout,
                 **kwargs):
        super(TransformerLayer, self).__init__()

        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseConvFF(d_model, d_inner, kernel_size, dropout,
                                         pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, mask=None):
        output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2))
        output *= mask
        output = self.pos_ff(output)
        output *= mask
        return output


class FFTransformer(nn.Module):
    def __init__(self, n_layer, n_head, d_model, d_head, d_inner, kernel_size,
                 dropout, dropatt, dropemb=0.0, embed_input=True,
                 n_embed=None, d_embed=None, padding_idx=0, pre_lnorm=False):
        super(FFTransformer, self).__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head
        self.padding_idx = padding_idx

        if embed_input:
            self.word_emb = nn.Embedding(n_embed, d_embed or d_model,
                                         padding_idx=self.padding_idx)
        else:
            self.word_emb = None

        self.pos_emb = PositionalEmbedding(self.d_model)
        self.drop = nn.Dropout(dropemb)
        self.layers = nn.ModuleList()

        for _ in range(n_layer):
            self.layers.append(
                TransformerLayer(
                    n_head, d_model, d_head, d_inner, kernel_size, dropout,
                    dropatt=dropatt, pre_lnorm=pre_lnorm)
            )

    def forward(self, dec_inp, seq_lens=None, conditioning=0):
        if self.word_emb is None:
            inp = dec_inp
            mask = mask_from_lens(seq_lens).unsqueeze(2)
        else:
            inp = self.word_emb(dec_inp)
            # [bsz x L x 1]
            mask = (dec_inp != self.padding_idx).unsqueeze(2)

        pos_seq = torch.arange(inp.size(1), device=inp.device).to(inp.dtype)
        pos_emb = self.pos_emb(pos_seq) * mask

        out = self.drop(inp + pos_emb + conditioning)

        for layer in self.layers:
            out = layer(out, mask=mask)

        # out = self.drop(out)
        return out, mask
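Note: PositionalEmbedding above builds the classic sinusoidal encoding as an outer product of positions and inverse frequencies. A quick self-contained shape check of the same computation (toy sizes, assuming nothing beyond the code above):

import torch

demb, L = 8, 5
inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
pos_seq = torch.arange(L, dtype=torch.float)

# Outer product of positions and inverse frequencies, as in forward() above.
sinusoid_inp = pos_seq.unsqueeze(-1) * inv_freq.unsqueeze(0)   # (L, demb/2)
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)

assert pos_emb.shape == (L, demb)
# Position 0 is all sin(0)=0 followed by all cos(0)=1:
assert torch.allclose(pos_emb[0], torch.tensor([0., 0., 0., 0., 1., 1., 1., 1.]))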
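Note: both transformer files carry a pre_lnorm flag that switches between post-LN (LayerNorm(x + sublayer(x)), the default here) and pre-LN (x + sublayer(LayerNorm(x))) residual ordering. A compact sketch of the two orderings with a generic sublayer, not the repo's own modules:

import torch.nn as nn

class Residual(nn.Module):
    def __init__(self, d_model, sublayer, pre_lnorm=False):
        super().__init__()
        self.sublayer = sublayer            # e.g. attention or feed-forward
        self.norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def forward(self, x):
        if self.pre_lnorm:
            return x + self.sublayer(self.norm(x))   # pre-LN
        return self.norm(x + self.sublayer(x))       # post-LN (default above)

# Usage: Residual(256, nn.Linear(256, 256)) wraps a linear sublayer.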
fastpitch/utils_trainplot_transformers/transformer_jit.py
DELETED
@@ -1,255 +0,0 @@
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from common.utils import mask_from_lens


class PositionalEmbedding(nn.Module):
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()
        self.demb = demb
        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq, bsz: Optional[int] = None):
        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
        if bsz is not None:
            return pos_emb[None, :, :].expand(bsz, -1, -1)
        else:
            return pos_emb[None, :, :]


class PositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(PositionwiseFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet = nn.Sequential(
            nn.Linear(d_model, d_inner), nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )

        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if self.pre_lnorm:
            # layer normalization + positionwise feed-forward
            core_out = self.CoreNet(self.layer_norm(inp))

            # residual connection
            output = core_out + inp
        else:
            # positionwise feed-forward
            core_out = self.CoreNet(inp)

            # residual connection + layer normalization
            output = self.layer_norm(inp + core_out)

        return output


class PositionwiseConvFF(nn.Module):
    def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False):
        super(PositionwiseConvFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet = nn.Sequential(
            nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)),
            nn.ReLU(),
            # nn.Dropout(dropout),  # worse convergence
            nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)),
            nn.Dropout(dropout),
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if self.pre_lnorm:
            # layer normalization + positionwise feed-forward
            core_out = inp.transpose(1, 2)
            core_out = self.CoreNet(self.layer_norm(core_out))
            core_out = core_out.transpose(1, 2)

            # residual connection
            output = core_out + inp
        else:
            # positionwise feed-forward
            core_out = inp.transpose(1, 2)
            core_out = self.CoreNet(core_out)
            core_out = core_out.transpose(1, 2)

            # residual connection + layer normalization
            output = self.layer_norm(inp + core_out)

        return output


class MultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1,
                 pre_lnorm=False):
        super(MultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.scale = 1 / (d_head ** 0.5)
        self.dropout = dropout
        self.pre_lnorm = pre_lnorm

        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head)
        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inp, attn_mask: Optional[torch.Tensor] = None):
        residual = inp

        if self.pre_lnorm:
            # layer normalization
            inp = self.layer_norm(inp)

        n_head, d_head = self.n_head, self.d_head

        head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=-1)
        head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head)
        head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head)
        head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head)

        # Head-major reshape, matching the view/permute below and transformer.py
        q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
        k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
        v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)

        attn_score = torch.bmm(q, k.transpose(1, 2))
        attn_score.mul_(self.scale)

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1)
            attn_mask = attn_mask.repeat(n_head, attn_mask.size(2), 1)
            attn_score.masked_fill_(attn_mask, -float('inf'))

        attn_prob = F.softmax(attn_score, dim=2)
        attn_prob = self.dropatt(attn_prob)
        attn_vec = torch.bmm(attn_prob, v)

        attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head)
        attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(
            inp.size(0), inp.size(1), n_head * d_head)

        # linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            # residual connection
            output = residual + attn_out
        else:
            # residual connection + layer normalization

            # XXX Running TorchScript on 20.02 and 20.03 containers crashes here
            # XXX Works well with 20.01-py3 container.
            # XXX dirty fix is:
            # XXX output = self.layer_norm(residual + attn_out).half()
            output = self.layer_norm(residual + attn_out)

        return output


class TransformerLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, kernel_size, dropout,
                 **kwargs):
        super(TransformerLayer, self).__init__()

        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseConvFF(d_model, d_inner, kernel_size, dropout,
                                         pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, mask):
        output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2))
        output *= mask
        output = self.pos_ff(output)
        output *= mask
        return output


class FFTransformer(nn.Module):
    def __init__(self, n_layer, n_head, d_model, d_head, d_inner, kernel_size,
                 dropout, dropatt, dropemb=0.0, embed_input=True,
                 n_embed=None, d_embed=None, padding_idx=0, pre_lnorm=False):
        super(FFTransformer, self).__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head
        self.padding_idx = padding_idx
        self.n_embed = n_embed

        self.embed_input = embed_input
        if embed_input:
            self.word_emb = nn.Embedding(n_embed, d_embed or d_model,
                                         padding_idx=self.padding_idx)
        else:
            self.word_emb = nn.Identity()

        self.pos_emb = PositionalEmbedding(self.d_model)
        self.drop = nn.Dropout(dropemb)
        self.layers = nn.ModuleList()

        for _ in range(n_layer):
            self.layers.append(
                TransformerLayer(
                    n_head, d_model, d_head, d_inner, kernel_size, dropout,
                    dropatt=dropatt, pre_lnorm=pre_lnorm)
            )

    def forward(self, dec_inp, seq_lens: Optional[torch.Tensor] = None,
                conditioning: Optional[torch.Tensor] = None):
        if not self.embed_input:
            inp = dec_inp
            assert seq_lens is not None
            mask = mask_from_lens(seq_lens).unsqueeze(2)
        else:
            inp = self.word_emb(dec_inp)
            # [bsz x L x 1]
            mask = (dec_inp != self.padding_idx).unsqueeze(2)

        pos_seq = torch.arange(inp.size(1), device=inp.device, dtype=inp.dtype)
        pos_emb = self.pos_emb(pos_seq) * mask
        if conditioning is not None:
            out = self.drop(inp + pos_emb + conditioning)
        else:
            out = self.drop(inp + pos_emb)

        for layer in self.layers:
            out = layer(out, mask=mask)

        # out = self.drop(out)
        return out, mask
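Note: transformer_jit.py differs from transformer.py mainly in TorchScript friendliness: Optional[...] type annotations, nn.Identity() instead of None for the missing embedding, and no default/keyword tricks in forward. A minimal sketch of scripting a module with an Optional argument (toy module, not the repo's FFTransformer):

from typing import Optional
import torch
import torch.nn as nn

class AddBias(nn.Module):
    # TorchScript needs the Optional annotation to accept "None or Tensor",
    # and the "is not None" check refines the type inside the branch.
    def forward(self, x: torch.Tensor,
                bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        if bias is not None:
            return x + bias
        return x

scripted = torch.jit.script(AddBias())
x = torch.ones(2)
assert torch.equal(scripted(x), x)
assert torch.equal(scripted(x, torch.ones(2)), x + 1)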
fastpitch/utils_trainplot_transformers/utils.py
DELETED
@@ -1,291 +0,0 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# MIT License
#
# Copyright (c) 2020 Jungil Kong
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# The following functions/classes were based on code from https://github.com/jik876/hifi-gan:
# init_weights, get_padding, AttrDict

import ctypes
import glob
import os
import re
import shutil
import warnings
from collections import defaultdict, OrderedDict
from pathlib import Path
from typing import Optional

import librosa
import numpy as np

import torch
import torch.distributed as dist
from scipy.io.wavfile import read


def mask_from_lens(lens, max_len: Optional[int] = None):
    if max_len is None:
        max_len = lens.max()
    ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
    mask = torch.lt(ids, lens.unsqueeze(1))
    return mask


def load_wav(full_path, torch_tensor=False):
    import soundfile  # flac
    data, sampling_rate = soundfile.read(full_path, dtype='int16')
    if torch_tensor:
        return torch.FloatTensor(data.astype(np.float32)), sampling_rate
    else:
        return data, sampling_rate


def load_wav_to_torch(full_path, force_sampling_rate=None):
    if force_sampling_rate is not None:
        data, sampling_rate = librosa.load(full_path, sr=force_sampling_rate)
    else:
        sampling_rate, data = read(full_path)

    return torch.FloatTensor(data.astype(np.float32)), sampling_rate


def load_filepaths_and_text(dataset_path, fnames, has_speakers=False, split="|"):
    def split_line(root, line):
        parts = line.strip().split(split)
        if has_speakers:
            paths, non_paths = parts[:-2], parts[-2:]
        else:
            paths, non_paths = parts[:-1], parts[-1:]
        return tuple(str(Path(root, p)) for p in paths) + tuple(non_paths)

    fpaths_and_text = []
    for fname in fnames:
        with open(fname, encoding='utf-8') as f:
            fpaths_and_text += [split_line(dataset_path, line) for line in f]
    return fpaths_and_text


def to_gpu(x):
    x = x.contiguous()
    return x.cuda(non_blocking=True) if torch.cuda.is_available() else x


def l2_promote():
    _libcudart = ctypes.CDLL('libcudart.so')
    # Set device limit on the current device
    # cudaLimitMaxL2FetchGranularity = 0x05
    pValue = ctypes.cast((ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int))
    _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
    _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
    assert pValue.contents.value == 128


def prepare_tmp(path):
    if path is None:
        return
    p = Path(path)
    if p.is_dir():
        warnings.warn(f'{p} exists. Removing...')
        shutil.rmtree(p, ignore_errors=True)
    p.mkdir(parents=False, exist_ok=False)


def print_once(*msg):
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(*msg)


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


class DefaultAttrDict(defaultdict):
    def __init__(self, *args, **kwargs):
        super(DefaultAttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

    def __getattr__(self, item):
        return self[item]


class BenchmarkStats:
    """ Tracks statistics used for benchmarking. """
    def __init__(self):
        self.num_frames = []
        self.losses = []
        self.mel_losses = []
        self.took = []

    def update(self, num_frames, losses, mel_losses, took):
        self.num_frames.append(num_frames)
        self.losses.append(losses)
        self.mel_losses.append(mel_losses)
        self.took.append(took)

    def get(self, n_epochs):
        frames_s = sum(self.num_frames[-n_epochs:]) / sum(self.took[-n_epochs:])
        return {'frames/s': frames_s,
                'loss': np.mean(self.losses[-n_epochs:]),
                'mel_loss': np.mean(self.mel_losses[-n_epochs:]),
                'took': np.mean(self.took[-n_epochs:]),
                'benchmark_epochs_num': n_epochs}

    def __len__(self):
        return len(self.losses)


class Checkpointer:

    def __init__(self, save_dir, keep_milestones=[]):
        self.save_dir = save_dir
        self.keep_milestones = keep_milestones

        find = lambda name: [
            (int(re.search(r"_(\d+).pt", fn).group(1)), fn)
            for fn in glob.glob(f"{save_dir}/{name}_checkpoint_*.pt")]

        tracked = sorted(find("FastPitch"), key=lambda t: t[0])
        self.tracked = OrderedDict(tracked)

    def last_checkpoint(self, output):

        def corrupted(fpath):
            try:
                torch.load(fpath, map_location="cpu")
                return False
            except Exception:
                warnings.warn(f"Cannot load {fpath}")
                return True

        saved = sorted(
            glob.glob(f"{output}/FastPitch_checkpoint_*.pt"),
            key=lambda f: int(re.search(r"_(\d+).pt", f).group(1)))

        if len(saved) >= 1 and not corrupted(saved[-1]):
            return saved[-1]
        elif len(saved) >= 2:
            return saved[-2]
        else:
            return None

    def maybe_load(self, model, optimizer, scaler, train_state, args,
                   ema_model=None):

        assert args.checkpoint_path is None or args.resume is False, (
            "Specify a single checkpoint source")

        fpath = None
        if args.checkpoint_path is not None:
            fpath = args.checkpoint_path
            self.tracked = OrderedDict()  # Do not track/delete prev ckpts
        elif args.resume:
            fpath = self.last_checkpoint(args.output)

        if fpath is None:
            return

        print_once(f"Loading model and optimizer state from {fpath}")
        ckpt = torch.load(fpath, map_location="cpu")
        train_state["epoch"] = ckpt["epoch"] + 1
        train_state["total_iter"] = ckpt["iteration"]

        no_pref = lambda sd: {re.sub("^module.", "", k): v for k, v in sd.items()}
        unwrap = lambda m: getattr(m, "module", m)

        unwrap(model).load_state_dict(no_pref(ckpt["state_dict"]))

        if ema_model is not None:
            unwrap(ema_model).load_state_dict(no_pref(ckpt["ema_state_dict"]))

        optimizer.load_state_dict(ckpt["optimizer"])

        if "scaler" in ckpt:
            scaler.load_state_dict(ckpt["scaler"])
        else:
            warnings.warn("AMP scaler state missing from the checkpoint.")

    def maybe_save(self, args, model, ema_model, optimizer, scaler, epoch,
                   total_iter, config):

        intermediate = (args.epochs_per_checkpoint > 0
                        and epoch % args.epochs_per_checkpoint == 0)
        final = epoch == args.epochs

        if not intermediate and not final and epoch not in self.keep_milestones:
            return

        rank = 0
        if dist.is_initialized():
            dist.barrier()
            rank = dist.get_rank()

        if rank != 0:
            return

        unwrap = lambda m: getattr(m, "module", m)
|
270 |
-
ckpt = {"epoch": epoch,
|
271 |
-
"iteration": total_iter,
|
272 |
-
"config": config,
|
273 |
-
"train_setup": args.__dict__,
|
274 |
-
"state_dict": unwrap(model).state_dict(),
|
275 |
-
"optimizer": optimizer.state_dict(),
|
276 |
-
"scaler": scaler.state_dict()}
|
277 |
-
if ema_model is not None:
|
278 |
-
ckpt["ema_state_dict"] = unwrap(ema_model).state_dict()
|
279 |
-
|
280 |
-
fpath = Path(args.output, f"FastPitch_checkpoint_{epoch}.pt")
|
281 |
-
print(f"Saving model and optimizer state at epoch {epoch} to {fpath}")
|
282 |
-
torch.save(ckpt, fpath)
|
283 |
-
|
284 |
-
# Remove old checkpoints; keep milestones and the last two
|
285 |
-
self.tracked[epoch] = fpath
|
286 |
-
for epoch in set(list(self.tracked)[:-2]) - set(self.keep_milestones):
|
287 |
-
try:
|
288 |
-
os.remove(self.tracked[epoch])
|
289 |
-
except:
|
290 |
-
pass
|
291 |
-
del self.tracked[epoch]
gradio_gui.py
DELETED
@@ -1,74 +0,0 @@
-import gradio as gr
-import syn_hifigan as syn
-#import syn_k_univnet_multi as syn
-import os, tempfile
-
-languages = {"South Sámi":0,
-             "North Sámi":1,
-             "Lule Sámi":2}
-
-speakers={"aj0": 0,
-          "aj1": 1,
-          "am": 2,
-          "bi": 3,
-          "kd": 4,
-          "ln": 5,
-          "lo": 6,
-          "ms": 7,
-          "mu": 8,
-          "sa": 9
-          }
-public=False
-
-tempdir = tempfile.gettempdir()
-
-tts = syn.Synthesizer()
-
-
-
-def speak(text, language,speaker,l_weight, s_weight, pace, postfilter): #pitch_shift,pitch_std):
-
-
-
-    # text frontend not implemented...
-    text = text.replace("...", "…")
-    print(speakers[speaker])
-    audio = tts.speak(text, output_file=f'{tempdir}/tmp', lang=languages[language],
-                      spkr=speakers[speaker], l_weight=l_weight, s_weight=s_weight,
-                      pace=pace, clarity=postfilter)
-
-    if not public:
-        try:
-            os.system("play "+tempdir+"/tmp.wav &")
-        except:
-            pass
-
-    return (22050, audio)
-
-
-
-controls = []
-controls.append(gr.Textbox(label="text", value="Suohtas duinna deaivvadit."))
-controls.append(gr.Dropdown(list(languages.keys()), label="language", value="North Sámi"))
-controls.append(gr.Dropdown(list(speakers.keys()), label="speaker", value="ms"))
-
-controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1, label="language weight"))
-controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1, label="speaker weight"))
-
-controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1.0, label="speech rate"))
-controls.append(gr.Slider(minimum=0., maximum=2, step=0.05, value=1.0, label="post-processing"))
-
-
-
-
-tts_gui = gr.Interface(
-    fn=speak,
-    inputs=controls,
-    outputs= gr.Audio(label="output"),
-    live=False
-
-)
-
-
-if __name__ == "__main__":
-    tts_gui.launch(share=public)
gradio_gui_katri.py
DELETED
@@ -1,73 +0,0 @@
-import gradio as gr
-#import syn_hifigan as syn
-import syn_k_univnet_multi as syn
-import os, tempfile
-
-languages = {"South Sámi":0,
-             "North Sámi":1,
-             "Lule Sámi":2}
-
-speakers={"aj0": 0,
-          "aj1": 1,
-          "am": 2,
-          "bi": 3,
-          "kd": 4,
-          "ln": 5,
-          "lo": 6,
-          "ms": 7,
-          "mu": 8,
-          "sa": 9
-          }
-public=True
-
-tempdir = tempfile.gettempdir()
-
-tts = syn.Synthesizer()
-
-
-
-def speak(text, language,speaker,l_weight, s_weight, pace): #pitch_shift,pitch_std):
-
-
-
-    # text frontend not implemented...
-    text = text.replace("...", "…")
-    print(speakers[speaker])
-    audio = tts.speak(text, output_file=f'{tempdir}/tmp', lang=languages[language],
-                      spkr=speakers[speaker], l_weight=l_weight, s_weight=s_weight,
-                      pace=pace)
-
-    if not public:
-        try:
-            os.system("play "+tempdir+"/tmp.wav &")
-        except:
-            pass
-
-    return (22050, audio)
-
-
-
-controls = []
-controls.append(gr.Textbox(label="text", value="Suohtas duinna deaivvadit."))
-controls.append(gr.Dropdown(list(languages.keys()), label="language", value="North Sámi"))
-controls.append(gr.Dropdown(list(speakers.keys()), label="speaker", value="ms"))
-
-controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1, label="Language weight"))
-controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1, label="Speaker weight"))
-#controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1.0, label="Pitch variance"))
-controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1.0, label="speech rate"))
-
-
-
-
-tts_gui = gr.Interface(
-    fn=speak,
-    inputs=controls,
-    outputs= gr.Audio(label="output"),
-    live=False
-
-)
-
-
-if __name__ == "__main__":
-    tts_gui.launch(share=public)
prepare_dataset.py
DELETED
@@ -1,180 +0,0 @@
-# *****************************************************************************
-#  Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
-#
-#  Redistribution and use in source and binary forms, with or without
-#  modification, are permitted provided that the following conditions are met:
-#      * Redistributions of source code must retain the above copyright
-#        notice, this list of conditions and the following disclaimer.
-#      * Redistributions in binary form must reproduce the above copyright
-#        notice, this list of conditions and the following disclaimer in the
-#        documentation and/or other materials provided with the distribution.
-#      * Neither the name of the NVIDIA CORPORATION nor the
-#        names of its contributors may be used to endorse or promote products
-#        derived from this software without specific prior written permission.
-#
-#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-#  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-#  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-#  DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
-#  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-#  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-#  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-#  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#
-# *****************************************************************************
-
-import argparse
-import time
-from pathlib import Path
-
-import torch
-import tqdm
-import dllogger as DLLogger
-from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
-from torch.utils.data import DataLoader
-
-from fastpitch.data_function import TTSCollate, TTSDataset
-
-
-def parse_args(parser):
-    """
-    Parse commandline arguments.
-    """
-    parser.add_argument('-d', '--dataset-path', type=str,
-                        default='./', help='Path to dataset')
-    parser.add_argument('--wav-text-filelists', required=True, nargs='+',
-                        type=str, help='Files with audio paths and text')
-    parser.add_argument('--extract-mels', action='store_true',
-                        help='Calculate spectrograms from .wav files')
-    parser.add_argument('--extract-pitch', action='store_true',
-                        help='Extract pitch')
-    parser.add_argument('--save-alignment-priors', action='store_true',
-                        help='Pre-calculate diagonal matrices of alignment of text to audio')
-    parser.add_argument('--log-file', type=str, default='preproc_log.json',
-                        help='Filename for logging')
-    parser.add_argument('--n-speakers', type=int, default=1)
-    parser.add_argument('--n-languages', type=int, default=1)
-    # Mel extraction
-    parser.add_argument('--max-wav-value', default=32768.0, type=float,
-                        help='Maximum audiowave value')
-    parser.add_argument('--sampling-rate', default=22050, type=int,
-                        help='Sampling rate')
-    parser.add_argument('--filter-length', default=1024, type=int,
-                        help='Filter length')
-    parser.add_argument('--hop-length', default=256, type=int,
-                        help='Hop (stride) length')
-    parser.add_argument('--win-length', default=1024, type=int,
-                        help='Window length')
-    parser.add_argument('--mel-fmin', default=0.0, type=float,
-                        help='Minimum mel frequency')
-    parser.add_argument('--mel-fmax', default=8000.0, type=float,
-                        help='Maximum mel frequency')
-    parser.add_argument('--n-mel-channels', type=int, default=80)
-    # Pitch extraction
-    parser.add_argument('--f0-method', default='pyin', type=str,
-                        choices=['pyin'], help='F0 estimation method')
-    parser.add_argument('--pitch-mean', default='214', type=float, ###
-                        help='Normalization value for pitch')
-    parser.add_argument('--pitch-std', default='65', type=float, ####
-                        help='Normalization value for pitch')
-    # Performance
-    parser.add_argument('-b', '--batch-size', default=1, type=int)
-    parser.add_argument('--n-workers', type=int, default=16)
-    return parser
-
-
-def main():
-    parser = argparse.ArgumentParser(description='FastPitch Data Pre-processing')
-    parser = parse_args(parser)
-    args, unk_args = parser.parse_known_args()
-    if len(unk_args) > 0:
-        raise ValueError(f'Invalid options {unk_args}')
-
-    DLLogger.init(backends=[JSONStreamBackend(Verbosity.DEFAULT, Path(args.dataset_path, args.log_file)),
-                            StdOutBackend(Verbosity.VERBOSE)])
-    for k, v in vars(args).items():
-        DLLogger.log(step="PARAMETER", data={k: v})
-    DLLogger.flush()
-
-    if args.extract_mels:
-        Path(args.dataset_path, 'mels').mkdir(parents=False, exist_ok=True)
-
-    if args.extract_pitch:
-        Path(args.dataset_path, 'pitch').mkdir(parents=False, exist_ok=True)
-
-    if args.save_alignment_priors:
-        Path(args.dataset_path, 'alignment_priors').mkdir(parents=False, exist_ok=True)
-
-    for filelist in args.wav_text_filelists:
-
-        print(f'Processing {filelist}...')
-
-        dataset = TTSDataset(
-            args.dataset_path,
-            filelist,
-            text_cleaners=['basic_cleaners'],
-            n_mel_channels=args.n_mel_channels,
-            p_arpabet=0.0,
-            n_speakers=args.n_speakers,
-            n_languages=args.n_languages,
-            load_mel_from_disk=False,
-            load_pitch_from_disk=False,
-            pitch_mean=args.pitch_mean,
-            pitch_std=args.pitch_std,
-            max_wav_value=args.max_wav_value,
-            sampling_rate=args.sampling_rate,
-            filter_length=args.filter_length,
-            hop_length=args.hop_length,
-            win_length=args.win_length,
-            mel_fmin=args.mel_fmin,
-            mel_fmax=args.mel_fmax,
-            betabinomial_online_dir=None,
-            pitch_online_dir=None,
-            pitch_online_method=args.f0_method)
-
-        data_loader = DataLoader(
-            dataset,
-            batch_size=args.batch_size,
-            shuffle=False,
-            sampler=None,
-            num_workers=args.n_workers,
-            collate_fn=TTSCollate(),
-            pin_memory=False,
-            drop_last=False)
-
-        all_filenames = set()
-        for i, batch in enumerate(tqdm.tqdm(data_loader)):
-            tik = time.time()
-
-            _, input_lens, mels, mel_lens, _, pitch, _, _, _, attn_prior, fpaths = batch
-
-            # Ensure filenames are unique
-            for p in fpaths:
-                fname = Path(p).name
-                if fname in all_filenames:
-                    raise ValueError(f'Filename is not unique: {fname}')
-                all_filenames.add(fname)
-
-            if args.extract_mels:
-                for j, mel in enumerate(mels):
-                    fname = Path(fpaths[j]).with_suffix('.pt').name
-                    fpath = Path(args.dataset_path, 'mels', fname)
-                    torch.save(mel[:, :mel_lens[j]], fpath)
-
-            if args.extract_pitch:
-                for j, p in enumerate(pitch):
-                    fname = Path(fpaths[j]).with_suffix('.pt').name
-                    fpath = Path(args.dataset_path, 'pitch', fname)
-                    torch.save(p[:mel_lens[j]], fpath)
-
-            if args.save_alignment_priors:
-                for j, prior in enumerate(attn_prior):
-                    fname = Path(fpaths[j]).with_suffix('.pt').name
-                    fpath = Path(args.dataset_path, 'alignment_priors', fname)
-                    torch.save(prior[:mel_lens[j], :input_lens[j]], fpath)
-
-
-if __name__ == '__main__':
-    main()
run_training_cluster_s.sh
DELETED
@@ -1,33 +0,0 @@
-#!/bin/bash
-#SBATCH --job-name=train_fastpitch
-#SBATCH --account=nn9866k
-#SBATCH --time=11:50:00
-#SBATCH --mem=16G
-#SBATCH --partition=accel
-#SBATCH --gres=gpu:1
-
-# == Logging
-
-#SBATCH --error="log_err"    # Save the error messages
-#SBATCH --output="log_out"   # Save the stdout
-
-## Set up job environment:
-# set -o errexit  # Exit the script on any error
-# set -o nounset  # Treat any unset variables as an error
-
-## Activate environment
-# source ~/.bashrc
-
-eval "$(conda shell.bash hook)"
-conda activate fastpitch
-
-# Setup monitoring
-nvidia-smi --query-gpu=timestamp,utilization.gpu,utilization.memory \
-    --format=csv --loop=1 > "gpu_util-$SLURM_JOB_ID.csv" &
-NVIDIA_MONITOR_PID=$!  # Capture PID of monitoring process
-
-# Run our computation
-bash scripts/train_2.sh
-
-# After computation stop monitoring
-kill -SIGINT "$NVIDIA_MONITOR_PID"
scripts/docker/build.sh
DELETED
@@ -1,3 +0,0 @@
-#!/usr/bin/env bash
-
-docker build . -t fastpitch:latest
scripts/docker/interactive.sh
DELETED
@@ -1,5 +0,0 @@
-#!/usr/bin/env bash
-
-PORT=${PORT:-8888}
-
-docker run --gpus=all -it --rm -e CUDA_VISIBLE_DEVICES --ipc=host -p $PORT:$PORT -v $PWD:/workspace/fastpitch/ fastpitch:latest bash
scripts/download_cmudict.sh
DELETED
@@ -1,10 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-: ${CMUDICT_DIR:="cmudict"}
-
-if [ ! -f $CMUDICT_DIR/cmudict-0.7b ]; then
-    echo "Downloading cmudict-0.7b ..."
-    wget https://github.com/Alexir/CMUdict/raw/master/cmudict-0.7b -qO $CMUDICT_DIR/cmudict-0.7b
-fi
scripts/download_dataset.sh
DELETED
@@ -1,17 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-scripts/download_cmudict.sh
-
-DATA_DIR="LJSpeech-1.1"
-LJS_ARCH="LJSpeech-1.1.tar.bz2"
-LJS_URL="http://data.keithito.com/data/speech/${LJS_ARCH}"
-
-if [ ! -d ${DATA_DIR} ]; then
-    echo "Downloading ${LJS_ARCH} ..."
-    wget -q ${LJS_URL}
-    echo "Extracting ${LJS_ARCH} ..."
-    tar jxvf ${LJS_ARCH}
-    rm -f ${LJS_ARCH}
-fi
scripts/download_models.sh
DELETED
@@ -1,63 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-MODEL_NAMES="$@"
-[ -z "$MODEL_NAMES" ] && { echo "Usage: $0 [fastpitch|waveglow|hifigan|hifigan-finetuned-fastpitch]"; exit 1; }
-
-function download_ngc_model() {
-    mkdir -p "$MODEL_DIR"
-
-    if [ ! -f "${MODEL_DIR}/${MODEL_ZIP}" ]; then
-        echo "Downloading ${MODEL_ZIP} ..."
-        wget --content-disposition -O ${MODEL_DIR}/${MODEL_ZIP} ${MODEL_URL} \
-            || { echo "ERROR: Failed to download ${MODEL_ZIP} from NGC"; exit 1; }
-    fi
-
-    if [ ! -f "${MODEL_DIR}/${MODEL}" ]; then
-        echo "Extracting ${MODEL} ..."
-        unzip -qo ${MODEL_DIR}/${MODEL_ZIP} -d ${MODEL_DIR} \
-            || { echo "ERROR: Failed to extract ${MODEL_ZIP}"; exit 1; }
-
-        echo "OK"
-
-    else
-        echo "${MODEL} already downloaded."
-    fi
-
-}
-
-for MODEL_NAME in $MODEL_NAMES
-do
-    case $MODEL_NAME in
-    "fastpitch")
-        MODEL_DIR="pretrained_models/fastpitch"
-        MODEL_ZIP="fastpitch_pyt_fp32_ckpt_v1_1_21.05.0.zip"
-        MODEL="nvidia_fastpitch_210824.pt"
-        MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/fastpitch_pyt_fp32_ckpt_v1_1/versions/21.05.0/zip"
-        ;;
-    "hifigan")
-        MODEL_DIR="pretrained_models/hifigan"
-        MODEL_ZIP="hifigan__pyt_ckpt_ds-ljs22khz_21.08.0_amp.zip"
-        MODEL="hifigan_gen_checkpoint_6500.pt"
-        MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/dle/hifigan__pyt_ckpt_ds-ljs22khz/versions/21.08.0_amp/zip"
-        ;;
-    "hifigan-finetuned-fastpitch")
-        MODEL_DIR="pretrained_models/hifigan"
-        MODEL_ZIP="hifigan__pyt_ckpt_mode-finetune_ds-ljs22khz_21.08.0_amp.zip"
-        MODEL="hifigan_gen_checkpoint_10000_ft.pt"
-        MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/dle/hifigan__pyt_ckpt_mode-finetune_ds-ljs22khz/versions/21.08.0_amp/zip"
-        ;;
-    "waveglow")
-        MODEL_DIR="pretrained_models/waveglow"
-        MODEL_ZIP="waveglow_ckpt_amp_256_20.01.0.zip"
-        MODEL="nvidia_waveglow256pyt_fp16.pt"
-        MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/waveglow_ckpt_amp_256/versions/20.01.0/zip"
-        ;;
-    *)
-        echo "Unrecognized model: ${MODEL_NAME}"
-        exit 2
-        ;;
-    esac
-    download_ngc_model "$MODEL_DIR" "$MODEL_ZIP" "$MODEL" "$MODEL_URL"
-done
scripts/inference_benchmark.sh
DELETED
@@ -1,16 +0,0 @@
-#!/usr/bin/env bash
-
-set -a
-
-: ${FILELIST:="phrases/benchmark_8_128.tsv"}
-: ${OUTPUT_DIR:="./output/audio_$(basename ${FILELIST} .tsv)"}
-: ${TORCHSCRIPT:=true}
-: ${BS_SEQUENCE:="1 4 8"}
-: ${WARMUP:=64}
-: ${REPEATS:=500}
-: ${AMP:=false}
-
-for BATCH_SIZE in $BS_SEQUENCE ; do
-    LOG_FILE="$OUTPUT_DIR"/perf-infer_amp-${AMP}_bs${BATCH_SIZE}.json
-    bash scripts/inference_example.sh "$@"
-done
scripts/inference_example.sh
DELETED
@@ -1,78 +0,0 @@
-#!/usr/bin/env bash
-
-export CUDNN_V8_API_ENABLED=1  # Keep the flag for older containers
-export TORCH_CUDNN_V8_API_ENABLED=1
-
-: ${DATASET_DIR:="sander_splits"}
-: ${BATCH_SIZE:=1}
-: ${FILELIST:="phrases/giehttjit.txt"}
-: ${AMP:=false}
-: ${TORCHSCRIPT:=true}
-: ${WARMUP:=0}
-: ${REPEATS:=1}
-: ${CPU:=false}
-: ${PHONE:=true}
-
-# Paths to pre-trained models downloadable from NVIDIA NGC (LJSpeech-1.1)
-FASTPITCH_LJ="output/FastPitch_checkpoint_660.pt"
-HIFIGAN_LJ="pretrained_models/hifigan/hifigan_gen_checkpoint_10000_ft.pt"
-WAVEGLOW_LJ="pretrained_models/waveglow/nvidia_waveglow256pyt_fp16.pt"
-
-# Mel-spectrogram generator (optional; can synthesize from ground-truth spectrograms)
-: ${FASTPITCH=$FASTPITCH_LJ}
-
-# Vocoder (set only one)
-#: ${HIFIGAN=$HIFIGAN_LJ}
-: ${WAVEGLOW=$WAVEGLOW_LJ}
-
-[[ "$FASTPITCH" == "$FASTPITCH_LJ" && ! -f "$FASTPITCH" ]] && { echo "Downloading $FASTPITCH from NGC..."; bash scripts/download_models.sh fastpitch; }
-[[ "$WAVEGLOW" == "$WAVEGLOW_LJ" && ! -f "$WAVEGLOW" ]] && { echo "Downloading $WAVEGLOW from NGC..."; bash scripts/download_models.sh waveglow; }
-[[ "$HIFIGAN" == "$HIFIGAN_LJ" && ! -f "$HIFIGAN" ]] && { echo "Downloading $HIFIGAN from NGC..."; bash scripts/download_models.sh hifigan-finetuned-fastpitch; }
-
-if [[ "$HIFIGAN" == "$HIFIGAN_LJ" && "$FASTPITCH" != "$FASTPITCH_LJ" ]]; then
-    echo -e "\nNOTE: Using HiFi-GAN checkpoint trained for the LJSpeech-1.1 dataset."
-    echo -e "NOTE: If you're using a different dataset, consider training a new HiFi-GAN model or switch to WaveGlow."
-    echo -e "NOTE: See $0 for details.\n"
-fi
-
-# Synthesis
-: ${SPEAKER:=0}
-: ${DENOISING:=0.01}
-
-if [ ! -n "$OUTPUT_DIR" ]; then
-    OUTPUT_DIR="./output/audio_$(basename ${FILELIST} .tsv)"
-    [ "$AMP" = true ] && OUTPUT_DIR+="_fp16"
-    [ "$AMP" = false ] && OUTPUT_DIR+="_fp32"
-    [ -n "$FASTPITCH" ] && OUTPUT_DIR+="_fastpitch"
-    [ ! -n "$FASTPITCH" ] && OUTPUT_DIR+="_gt-mel"
-    [ -n "$WAVEGLOW" ] && OUTPUT_DIR+="_waveglow"
-    [ -n "$HIFIGAN" ] && OUTPUT_DIR+="_hifigan"
-    OUTPUT_DIR+="_denoise-"${DENOISING}
-fi
-: ${LOG_FILE:="$OUTPUT_DIR/nvlog_infer.json"}
-mkdir -p "$OUTPUT_DIR"
-
-echo -e "\nAMP=$AMP, batch_size=$BATCH_SIZE\n"
-
-ARGS=""
-ARGS+=" --cuda"
-# ARGS+=" --cudnn-benchmark"  # Enable for benchmarking or long operation
-ARGS+=" --dataset-path $DATASET_DIR"
-ARGS+=" -i $FILELIST"
-ARGS+=" -o $OUTPUT_DIR"
-ARGS+=" --log-file $LOG_FILE"
-ARGS+=" --batch-size $BATCH_SIZE"
-ARGS+=" --denoising-strength $DENOISING"
-ARGS+=" --warmup-steps $WARMUP"
-ARGS+=" --repeats $REPEATS"
-ARGS+=" --speaker $SPEAKER"
-[ "$CPU" = false ] && ARGS+=" --cuda"
-[ "$CPU" = false ] && ARGS+=" --cudnn-benchmark"
-[ "$AMP" = true ] && ARGS+=" --amp"
-[ "$TORCHSCRIPT" = true ] && ARGS+=" --torchscript"
-[ -n "$HIFIGAN" ] && ARGS+=" --hifigan $HIFIGAN"
-[ -n "$WAVEGLOW" ] && ARGS+=" --waveglow $WAVEGLOW"
-[ -n "$FASTPITCH" ] && ARGS+=" --fastpitch $FASTPITCH"
-[ "$PHONE" = true ] && ARGS+=" --p-arpabet 1.0"
-
-python inference.py $ARGS "$@"
scripts/prepare_dataset.sh
DELETED
@@ -1,19 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-: ${DATA_DIR:=ALL_SAMI}
-: ${ARGS="--extract-mels"}
-
-python prepare_dataset.py \
-    --wav-text-filelists filelists/all_sami_filelist_shuf_200_train.txt \
-    --n-workers 8 \
-    --batch-size 1 \
-    --dataset-path $DATA_DIR \
-    --extract-pitch \
-    --f0-method pyin \
-    --pitch-mean 150 \
-    --pitch-std 40 \
-    --n-speakers 10 \
-    --n-languages 3 \
-    $ARGS
scripts/train.sh
DELETED
@@ -1,100 +0,0 @@
-#!/usr/bin/env bash
-
-export OMP_NUM_THREADS=1
-
-: ${NUM_GPUS:=8}
-: ${BATCH_SIZE:=16}
-: ${GRAD_ACCUMULATION:=2}
-: ${OUTPUT_DIR:="./output"}
-: ${LOG_FILE:=$OUTPUT_DIR/nvlog.json}
-: ${DATASET_PATH:=LJSpeech-1.1}
-: ${TRAIN_FILELIST:=filelists/ljs_audio_pitch_text_train_v3.txt}
-: ${VAL_FILELIST:=filelists/ljs_audio_pitch_text_val.txt}
-: ${AMP:=false}
-: ${SEED:=""}
-
-: ${LEARNING_RATE:=0.1}
-
-# Adjust these when the amount of data changes
-: ${EPOCHS:=1000}
-: ${EPOCHS_PER_CHECKPOINT:=20}
-: ${WARMUP_STEPS:=1000}
-: ${KL_LOSS_WARMUP:=100}
-
-# Train a mixed phoneme/grapheme model
-: ${PHONE:=true}
-# Enable energy conditioning
-: ${ENERGY:=true}
-: ${TEXT_CLEANERS:=english_cleaners_v2}
-# Add dummy space prefix/suffix if audio is not precisely trimmed
-: ${APPEND_SPACES:=false}
-
-: ${LOAD_PITCH_FROM_DISK:=true}
-: ${LOAD_MEL_FROM_DISK:=false}
-
-# For multispeaker models, add speaker ID = {0, 1, ...} as the last filelist column
-: ${NSPEAKERS:=1}
-: ${SAMPLING_RATE:=22050}
-
-# Adjust env variables to maintain the global batch size: NUM_GPUS x BATCH_SIZE x GRAD_ACCUMULATION = 256.
-GBS=$(($NUM_GPUS * $BATCH_SIZE * $GRAD_ACCUMULATION))
-[ $GBS -ne 256 ] && echo -e "\nWARNING: Global batch size changed from 256 to ${GBS}."
-echo -e "\nAMP=$AMP, ${NUM_GPUS}x${BATCH_SIZE}x${GRAD_ACCUMULATION}" \
-        "(global batch size ${GBS})\n"
-
-ARGS=""
-ARGS+=" --cuda"
-ARGS+=" -o $OUTPUT_DIR"
-ARGS+=" --log-file $LOG_FILE"
-ARGS+=" --dataset-path $DATASET_PATH"
-ARGS+=" --training-files $TRAIN_FILELIST"
-ARGS+=" --validation-files $VAL_FILELIST"
-ARGS+=" -bs $BATCH_SIZE"
-ARGS+=" --grad-accumulation $GRAD_ACCUMULATION"
-ARGS+=" --optimizer lamb"
-ARGS+=" --epochs $EPOCHS"
-ARGS+=" --epochs-per-checkpoint $EPOCHS_PER_CHECKPOINT"
-ARGS+=" --resume"
-ARGS+=" --warmup-steps $WARMUP_STEPS"
-ARGS+=" -lr $LEARNING_RATE"
-ARGS+=" --weight-decay 1e-6"
-ARGS+=" --grad-clip-thresh 1000.0"
-ARGS+=" --dur-predictor-loss-scale 0.1"
-ARGS+=" --pitch-predictor-loss-scale 0.1"
-ARGS+=" --trainloader-repeats 100"
-ARGS+=" --validation-freq 10"
-
-# Autoalign & new features
-ARGS+=" --kl-loss-start-epoch 0"
-ARGS+=" --kl-loss-warmup-epochs $KL_LOSS_WARMUP"
-ARGS+=" --text-cleaners $TEXT_CLEANERS"
-ARGS+=" --n-speakers $NSPEAKERS"
-
-[ "$AMP" = "true" ] && ARGS+=" --amp"
-[ "$PHONE" = "true" ] && ARGS+=" --p-arpabet 1.0"
-[ "$ENERGY" = "true" ] && ARGS+=" --energy-conditioning"
-[ "$SEED" != "" ] && ARGS+=" --seed $SEED"
-[ "$LOAD_MEL_FROM_DISK" = true ] && ARGS+=" --load-mel-from-disk"
-[ "$LOAD_PITCH_FROM_DISK" = true ] && ARGS+=" --load-pitch-from-disk"
-[ "$PITCH_ONLINE_DIR" != "" ] && ARGS+=" --pitch-online-dir $PITCH_ONLINE_DIR"  # e.g., /dev/shm/pitch
-[ "$PITCH_ONLINE_METHOD" != "" ] && ARGS+=" --pitch-online-method $PITCH_ONLINE_METHOD"
-[ "$APPEND_SPACES" = true ] && ARGS+=" --prepend-space-to-text"
-[ "$APPEND_SPACES" = true ] && ARGS+=" --append-space-to-text"
-
-if [ "$SAMPLING_RATE" == "44100" ]; then
-    ARGS+=" --sampling-rate 44100"
-    ARGS+=" --filter-length 2048"
-    ARGS+=" --hop-length 512"
-    ARGS+=" --win-length 2048"
-    ARGS+=" --mel-fmin 0.0"
-    ARGS+=" --mel-fmax 22050.0"
-
-elif [ "$SAMPLING_RATE" != "22050" ]; then
-    echo "Unknown sampling rate $SAMPLING_RATE"
-    exit 1
-fi
-
-mkdir -p "$OUTPUT_DIR"
-
-: ${DISTRIBUTED:="-m torch.distributed.launch --nproc_per_node $NUM_GPUS"}
-python $DISTRIBUTED train.py $ARGS "$@"
scripts/train_multilang.sh
DELETED
@@ -1,110 +0,0 @@
-#!/usr/bin/env bash
-
-export OMP_NUM_THREADS=1
-
-: ${NUM_GPUS:=1}
-: ${BATCH_SIZE:=1}
-: ${GRAD_ACCUMULATION:=32}
-: ${OUTPUT_DIR:="./output_multilang"}
-: ${LOG_FILE:=$OUTPUT_DIR/nvlog.json}
-: ${DATASET_PATH:=ALL_SAMI}
-#: ${DATASET_PATH:=mikal_urheim}
-#: ${TRAIN_FILELIST:=filelists/smj_sander_text_noshorts_shuff_pitch.txt}
-#: ${TRAIN_FILELIST:=filelists/mikal_urheim_pitch_shuf.txt}
-: ${TRAIN_FILELIST:=filelists/all_sami_filelist_shuf_200_train.txt}
-#: ${VAL_FILELIST:=filelists/smj_sander_text_noshorts_shuff_val_pitch.txt}
-#: ${VAL_FILELIST:=filelists/mikal_urheim_pitch_shuf_val.txt}
-: ${VAL_FILELIST:=filelists/all_sami_filelist_shuf_200_val.txt}
-: ${AMP:=false}
-: ${SEED:=""}
-
-: ${LEARNING_RATE:=0.1}
-
-# Adjust these when the amount of data changes
-: ${EPOCHS:=1000}
-: ${EPOCHS_PER_CHECKPOINT:=10}
-: ${WARMUP_STEPS:=1000}
-: ${KL_LOSS_WARMUP:=100}
-
-# Train a mixed phoneme/grapheme model
-: ${PHONE:=false}
-# Enable energy conditioning
-: ${ENERGY:=true}
-: ${TEXT_CLEANERS:=basic_cleaners}
-: ${SYMBOL_SET:=all_sami}
-# Add dummy space prefix/suffix if audio is not precisely trimmed
-: ${APPEND_SPACES:=false}
-
-: ${LOAD_PITCH_FROM_DISK:=true}  # was true
-: ${LOAD_MEL_FROM_DISK:=false}
-
-# For multispeaker models, add speaker ID = {0, 1, ...} as the last filelist column
-: ${NSPEAKERS:=10}  # 10
-: ${NLANGUAGES:=3}  # 3
-: ${SAMPLING_RATE:=22050}
-
-# Adjust env variables to maintain the global batch size: NUM_GPUS x BATCH_SIZE x GRAD_ACCUMULATION = 256.
-GBS=$(($NUM_GPUS * $BATCH_SIZE * $GRAD_ACCUMULATION))
-[ $GBS -ne 256 ] && echo -e "\nWARNING: Global batch size changed from 256 to ${GBS}."
-echo -e "\nAMP=$AMP, ${NUM_GPUS}x${BATCH_SIZE}x${GRAD_ACCUMULATION}" \
-        "(global batch size ${GBS})\n"
-
-ARGS=""
-ARGS+=" --cuda"
-ARGS+=" -o $OUTPUT_DIR"
-ARGS+=" --log-file $LOG_FILE"
-ARGS+=" --dataset-path $DATASET_PATH"
-ARGS+=" --training-files $TRAIN_FILELIST"
-ARGS+=" --validation-files $VAL_FILELIST"
-ARGS+=" -bs $BATCH_SIZE"
-ARGS+=" --grad-accumulation $GRAD_ACCUMULATION"
-ARGS+=" --optimizer lamb"  #adam
-ARGS+=" --epochs $EPOCHS"
-ARGS+=" --epochs-per-checkpoint $EPOCHS_PER_CHECKPOINT"
-ARGS+=" --resume"
-ARGS+=" --warmup-steps $WARMUP_STEPS"
-ARGS+=" -lr $LEARNING_RATE"
-ARGS+=" --weight-decay 1e-6"
-ARGS+=" --grad-clip-thresh 1000.0"
-ARGS+=" --dur-predictor-loss-scale 0.1"
-ARGS+=" --pitch-predictor-loss-scale 0.1"
-ARGS+=" --trainloader-repeats 100"
-ARGS+=" --validation-freq 1"  #10
-
-# Autoalign & new features
-ARGS+=" --kl-loss-start-epoch 0"
-ARGS+=" --kl-loss-warmup-epochs $KL_LOSS_WARMUP"
-ARGS+=" --text-cleaners $TEXT_CLEANERS"
-ARGS+=" --n-speakers $NSPEAKERS"
-ARGS+=" --n-languages $NLANGUAGES"
-ARGS+=" --symbol-set $SYMBOL_SET"
-
-[ "$AMP" = "true" ] && ARGS+=" --amp"
-[ "$PHONE" = "true" ] && ARGS+=" --p-arpabet 1.0"
-[ "$ENERGY" = "true" ] && ARGS+=" --energy-conditioning"
-[ "$SEED" != "" ] && ARGS+=" --seed $SEED"
-[ "$LOAD_MEL_FROM_DISK" = true ] && ARGS+=" --load-mel-from-disk"
-[ "$LOAD_PITCH_FROM_DISK" = true ] && ARGS+=" --load-pitch-from-disk"
-[ "$PITCH_ONLINE_DIR" != "" ] && ARGS+=" --pitch-online-dir $PITCH_ONLINE_DIR"  # e.g., /dev/shm/pitch
-[ "$PITCH_ONLINE_METHOD" != "" ] && ARGS+=" --pitch-online-method $PITCH_ONLINE_METHOD"
-[ "$APPEND_SPACES" = true ] && ARGS+=" --prepend-space-to-text"
-[ "$APPEND_SPACES" = true ] && ARGS+=" --append-space-to-text"
-
-if [ "$SAMPLING_RATE" == "44100" ]; then
-    ARGS+=" --sampling-rate 44100"
-    ARGS+=" --filter-length 2048"
-    ARGS+=" --hop-length 512"
-    ARGS+=" --win-length 2048"
-    ARGS+=" --mel-fmin 0.0"
-    ARGS+=" --mel-fmax 22050.0"
-
-elif [ "$SAMPLING_RATE" != "22050" ]; then
-    echo "Unknown sampling rate $SAMPLING_RATE"
-    exit 1
-fi
-
-mkdir -p "$OUTPUT_DIR"
-
-: ${DISTRIBUTED:="-m torch.distributed.launch --nproc_per_node $NUM_GPUS"}
-#python $DISTRIBUTED train.py $ARGS "$@"
-python train_1_with_plot_multilang.py $ARGS "$@"
train_1_with_plot_multilang.py
DELETED
@@ -1,593 +0,0 @@
|
|
1 |
-
# *****************************************************************************
|
2 |
-
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
-
#
|
4 |
-
# Redistribution and use in source and binary forms, with or without
|
5 |
-
# modification, are permitted provided that the following conditions are met:
|
6 |
-
# * Redistributions of source code must retain the above copyright
|
7 |
-
# notice, this list of conditions and the following disclaimer.
|
8 |
-
# * Redistributions in binary form must reproduce the above copyright
|
9 |
-
# notice, this list of conditions and the following disclaimer in the
|
10 |
-
# documentation and/or other materials provided with the distribution.
|
11 |
-
# * Neither the name of the NVIDIA CORPORATION nor the
|
12 |
-
# names of its contributors may be used to endorse or promote products
|
13 |
-
# derived from this software without specific prior written permission.
|
14 |
-
#
|
15 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
16 |
-
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
17 |
-
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
18 |
-
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
19 |
-
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
20 |
-
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
21 |
-
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
22 |
-
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
23 |
-
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
24 |
-
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
25 |
-
#
|
26 |
-
# *****************************************************************************
|
27 |
-
|
28 |
-
import argparse
|
29 |
-
import copy
|
30 |
-
import os
|
31 |
-
import time
|
32 |
-
from collections import defaultdict, OrderedDict
|
33 |
-
from itertools import cycle
|
34 |
-
|
35 |
-
import numpy as np
|
36 |
-
import torch
|
37 |
-
import torch.distributed as dist
|
38 |
-
import amp_C
|
39 |
-
from apex.optimizers import FusedAdam, FusedLAMB
|
40 |
-
from torch.nn.parallel import DistributedDataParallel
|
41 |
-
from torch.utils.data import DataLoader
|
42 |
-
from torch.utils.data.distributed import DistributedSampler
|
43 |
-
|
44 |
-
import common.tb_dllogger as logger
|
45 |
-
import models
|
46 |
-
from common.tb_dllogger import log
|
47 |
-
from common.repeated_dataloader import (RepeatedDataLoader,
|
48 |
-
RepeatedDistributedSampler)
|
49 |
-
from common.text import cmudict
|
50 |
-
from common.utils import BenchmarkStats, Checkpointer, prepare_tmp
|
51 |
-
from fastpitch.attn_loss_function import AttentionBinarizationLoss
|
52 |
-
from fastpitch.data_function import batch_to_gpu, TTSCollate, TTSDataset
|
53 |
-
from fastpitch.loss_function import FastPitchLoss
|
54 |
-
|
55 |
-
import matplotlib.pyplot as plt
|
56 |
-
|
57 |
-
def parse_args(parser):
|
58 |
-
parser.add_argument('-o', '--output', type=str, required=True,
|
59 |
-
help='Directory to save checkpoints')
|
60 |
-
parser.add_argument('-d', '--dataset-path', type=str, default='./',
|
61 |
-
help='Path to dataset')
|
62 |
-
parser.add_argument('--log-file', type=str, default=None,
|
63 |
-
help='Path to a DLLogger log file')
|
64 |
-
|
65 |
-
train = parser.add_argument_group('training setup')
|
66 |
-
train.add_argument('--epochs', type=int, required=True,
|
67 |
-
help='Number of total epochs to run')
|
68 |
-
train.add_argument('--epochs-per-checkpoint', type=int, default=50,
|
69 |
-
help='Number of epochs per checkpoint')
|
70 |
-
train.add_argument('--checkpoint-path', type=str, default=None,
|
71 |
-
help='Checkpoint path to resume training')
|
72 |
-
train.add_argument('--keep-milestones', default=list(range(100, 1000, 100)),
|
73 |
-
type=int, nargs='+',
|
74 |
-
help='Milestone checkpoints to keep from removing')
|
75 |
-
train.add_argument('--resume', action='store_true',
|
76 |
-
help='Resume training from the last checkpoint')
|
77 |
-
train.add_argument('--seed', type=int, default=1234,
|
78 |
-
help='Seed for PyTorch random number generators')
|
79 |
-
train.add_argument('--amp', action='store_true',
|
80 |
-
help='Enable AMP')
|
81 |
-
train.add_argument('--cuda', action='store_true',
|
82 |
-
help='Run on GPU using CUDA')
|
83 |
-
train.add_argument('--cudnn-benchmark', action='store_true',
|
84 |
-
help='Enable cudnn benchmark mode')
|
85 |
-
train.add_argument('--ema-decay', type=float, default=0,
|
86 |
-
help='Discounting factor for training weights EMA')
|
87 |
-
train.add_argument('--grad-accumulation', type=int, default=1,
|
88 |
-
help='Training steps to accumulate gradients for')
|
89 |
-
train.add_argument('--kl-loss-start-epoch', type=int, default=250,
|
90 |
-
help='Start adding the hard attention loss term')
|
91 |
-
train.add_argument('--kl-loss-warmup-epochs', type=int, default=100,
|
92 |
-
help='Gradually increase the hard attention loss term')
|
93 |
-
train.add_argument('--kl-loss-weight', type=float, default=1.0,
|
94 |
-
help='Gradually increase the hard attention loss term')
|
95 |
-
train.add_argument('--benchmark-epochs-num', type=int, default=20,
|
96 |
-
help='Number of epochs for calculating final stats')
|
97 |
-
train.add_argument('--validation-freq', type=int, default=1,
|
98 |
-
help='Validate every N epochs to use less compute')
|
99 |
-
|
100 |
-
opt = parser.add_argument_group('optimization setup')
|
101 |
-
opt.add_argument('--optimizer', type=str, default='lamb',
|
102 |
-
help='Optimization algorithm')
|
103 |
-
opt.add_argument('-lr', '--learning-rate', type=float, required=True,
|
104 |
-
help='Learing rate')
|
105 |
-
opt.add_argument('--weight-decay', default=1e-6, type=float,
|
106 |
-
help='Weight decay')
|
107 |
-
opt.add_argument('--grad-clip-thresh', default=1000.0, type=float,
|
108 |
-
help='Clip threshold for gradients')
|
109 |
-
opt.add_argument('-bs', '--batch-size', type=int, required=True,
|
110 |
-
help='Batch size per GPU')
|
111 |
-
opt.add_argument('--warmup-steps', type=int, default=1000,
|
112 |
-
help='Number of steps for lr warmup')
|
113 |
-
opt.add_argument('--dur-predictor-loss-scale', type=float,
|
114 |
-
default=1.0, help='Rescale duration predictor loss')
|
115 |
-
opt.add_argument('--pitch-predictor-loss-scale', type=float,
|
116 |
-
default=1.0, help='Rescale pitch predictor loss')
|
117 |
-
opt.add_argument('--attn-loss-scale', type=float,
|
118 |
-
default=1.0, help='Rescale alignment loss')
|
119 |
-
|
120 |
-
data = parser.add_argument_group('dataset parameters')
|
121 |
-
data.add_argument('--training-files', type=str, nargs='*', required=True,
|
122 |
-
help='Paths to training filelists.')
|
123 |
-
data.add_argument('--validation-files', type=str, nargs='*',
|
124 |
-
required=True, help='Paths to validation filelists')
|
125 |
-
data.add_argument('--text-cleaners', nargs='*',
|
126 |
-
default=['english_cleaners'], type=str,
|
127 |
-
help='Type of text cleaners for input text')
|
128 |
-
data.add_argument('--symbol-set', type=str, default='english_basic',
|
129 |
-
help='Define symbol set for input text')
|
130 |
-
data.add_argument('--p-arpabet', type=float, default=0.0,
|
131 |
-
help='Probability of using arpabets instead of graphemes '
|
132 |
-
'for each word; set 0 for pure grapheme training')
|
133 |
-
data.add_argument('--heteronyms-path', type=str, default='cmudict/heteronyms',
|
134 |
-
help='Path to the list of heteronyms')
|
135 |
-
data.add_argument('--cmudict-path', type=str, default='cmudict/cmudict-0.7b',
|
136 |
-
help='Path to the pronouncing dictionary')
|
137 |
-
data.add_argument('--prepend-space-to-text', action='store_true',
|
138 |
-
help='Capture leading silence with a space token')
|
139 |
-
data.add_argument('--append-space-to-text', action='store_true',
|
140 |
-
help='Capture trailing silence with a space token')
|
141 |
-
data.add_argument('--num-workers', type=int, default=2, # 6
|
142 |
-
help='Subprocesses for train and val DataLoaders')
|
143 |
-
data.add_argument('--trainloader-repeats', type=int, default=100,
|
144 |
-
help='Repeats the dataset to prolong epochs')
|
145 |
-
|
146 |
-
cond = parser.add_argument_group('data for conditioning')
|
147 |
-
cond.add_argument('--n-speakers', type=int, default=1,
|
148 |
-
help='Number of speakers in the dataset. '
|
149 |
-
'n_speakers > 1 enables speaker embeddings')
|
150 |
-
# ANT: added language
|
151 |
-
cond.add_argument('--n-languages', type=int, default=1,
|
152 |
-
help='Number of languages in the dataset. '
|
153 |
-
'n_languages > 1 enables language embeddings')
|
154 |
-
|
155 |
-
cond.add_argument('--load-pitch-from-disk', action='store_true',
|
156 |
-
help='Use pitch cached on disk with prepare_dataset.py')
|
157 |
-
cond.add_argument('--pitch-online-method', default='pyin',
|
158 |
-
choices=['pyin'],
|
159 |
-
help='Calculate pitch on the fly during trainig')
|
160 |
-
cond.add_argument('--pitch-online-dir', type=str, default=None,
|
161 |
-
help='A directory for storing pitch calculated on-line')
|
162 |
-
cond.add_argument('--pitch-mean', type=float, default=125.626816, #default=214.72203,
|
163 |
-
help='Normalization value for pitch')
|
164 |
-
cond.add_argument('--pitch-std', type=float, default=37.52, #default=65.72038,
|
165 |
-
help='Normalization value for pitch')
|
166 |
-
cond.add_argument('--load-mel-from-disk', action='store_true',
|
167 |
-
help='Use mel-spectrograms cache on the disk') # XXX
|
168 |
-
|
169 |
-
audio = parser.add_argument_group('audio parameters')
|
170 |
-
audio.add_argument('--max-wav-value', default=32768.0, type=float,
|
171 |
-
help='Maximum audiowave value')
|
172 |
-
audio.add_argument('--sampling-rate', default=22050, type=int,
|
173 |
-
help='Sampling rate')
|
174 |
-
audio.add_argument('--filter-length', default=1024, type=int,
|
175 |
-
help='Filter length')
|
176 |
-
audio.add_argument('--hop-length', default=256, type=int,
|
177 |
-
help='Hop (stride) length')
|
178 |
-
audio.add_argument('--win-length', default=1024, type=int,
|
179 |
-
help='Window length')
|
180 |
-
audio.add_argument('--mel-fmin', default=0.0, type=float,
|
181 |
-
help='Minimum mel frequency')
|
182 |
-
audio.add_argument('--mel-fmax', default=8000.0, type=float,
|
183 |
-
help='Maximum mel frequency')
|
184 |
-
|
185 |
-
dist = parser.add_argument_group('distributed setup')
|
186 |
-
dist.add_argument('--local_rank', type=int, default=os.getenv('LOCAL_RANK', 0),
|
187 |
-
help='Rank of the process for multiproc; do not set manually')
|
188 |
-
dist.add_argument('--world_size', type=int, default=os.getenv('WORLD_SIZE', 1),
|
189 |
-
help='Number of processes for multiproc; do not set manually')
|
190 |
-
return parser
|
191 |
-
|
192 |
-
|
193 |
-
def reduce_tensor(tensor, num_gpus):
|
194 |
-
rt = tensor.clone()
|
195 |
-
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
|
196 |
-
return rt.true_divide(num_gpus)
|
197 |
-
|
198 |
-
|
199 |
-
def init_distributed(args, world_size, rank):
|
200 |
-
assert torch.cuda.is_available(), "Distributed mode requires CUDA."
|
201 |
-
print("Initializing distributed training")
|
202 |
-
|
203 |
-
# Set cuda device so everything is done on the right GPU.
|
204 |
-
torch.cuda.set_device(rank % torch.cuda.device_count())
|
205 |
-
|
206 |
-
# Initialize distributed communication
|
207 |
-
dist.init_process_group(backend=('nccl' if args.cuda else 'gloo'),
|
208 |
-
init_method='env://')
|
209 |
-
print("Done initializing distributed training")
|
210 |
-
|
211 |
-
|
212 |
-
def validate(model, epoch, total_iter, criterion, val_loader, distributed_run,
|
213 |
-
batch_to_gpu, local_rank, ema=False):
|
214 |
-
was_training = model.training
|
215 |
-
model.eval()
|
216 |
-
|
217 |
-
tik = time.perf_counter()
|
218 |
-
with torch.no_grad():
|
219 |
-
val_meta = defaultdict(float)
|
220 |
-
val_num_frames = 0
|
221 |
-
for i, batch in enumerate(val_loader):
|
222 |
-
x, y, num_frames = batch_to_gpu(batch)
|
223 |
-
y_pred = model(x)
|
224 |
-
loss, meta = criterion(y_pred, y, is_training=False, meta_agg='sum')
|
225 |
-
|
226 |
-
if distributed_run:
|
227 |
-
for k, v in meta.items():
|
228 |
-
val_meta[k] += reduce_tensor(v, 1)
|
229 |
-
val_num_frames += reduce_tensor(num_frames.data, 1).item()
|
230 |
-
else:
|
231 |
-
for k, v in meta.items():
|
232 |
-
val_meta[k] += v
|
233 |
-
val_num_frames += num_frames.item()
|
234 |
-
|
235 |
-
# NOTE: ugly patch to visualize the first utterance of the validation corpus.
|
236 |
-
# The goal is to determine if the training is progressing properly
|
237 |
-
if (i == 0) and (local_rank == 0) and (not ema):
|
238 |
-
# Plot some debug information
|
239 |
-
fig, axs = plt.subplots(2, 2, figsize=(21,14))
|
240 |
-
|
241 |
-
# - Mel-spectrogram
|
242 |
-
pred_mel = y_pred[0][0, :, :].cpu().detach().numpy().astype(np.float32).T
|
243 |
-
orig_mel = y[0][0, :, :].cpu().detach().numpy().astype(np.float32)
|
244 |
-
axs[0,0].imshow(orig_mel, aspect='auto', origin='lower', interpolation='nearest')
|
245 |
-
axs[1,0].imshow(pred_mel, aspect='auto', origin='lower', interpolation='nearest')
|
246 |
-
|
247 |
-
# Prosody
|
248 |
-
f0_pred = y_pred[4][0, :].cpu().detach().numpy().astype(np.float32)
|
249 |
-
f0_ori = y_pred[5][0, :].cpu().detach().numpy().astype(np.float32)
|
250 |
-
axs[1,1].plot(f0_ori)
|
251 |
-
axs[1,1].plot(f0_pred)
|
252 |
-
|
253 |
-
# # Duration
|
254 |
-
# att_pred = y_pred[2][0, :].cpu().detach().numpy().astype(np.float32)
|
255 |
-
# att_ori = x[7][0,:].cpu().detach().numpy().astype(np.float32)
|
256 |
-
# axs[0,1].imshow(att_ori, aspect='auto', origin='lower', interpolation='nearest')
|
257 |
-
|
258 |
-
if not os.path.exists("debug_epoch/"):
|
259 |
-
os.makedirs("debug_epoch_laila/")
|
260 |
-
|
261 |
-
fig.savefig(f'debug_epoch/{epoch:06d}.png', bbox_inches='tight')
|
262 |
-
|
263 |
-
val_meta = {k: v / len(val_loader.dataset) for k, v in val_meta.items()}
|
264 |
-
|
265 |
-
val_meta['took'] = time.perf_counter() - tik
|
266 |
-
|
267 |
-
log((epoch,) if epoch is not None else (), tb_total_steps=total_iter,
|
268 |
-
subset='val_ema' if ema else 'val',
|
269 |
-
data=OrderedDict([
|
270 |
-
('loss', val_meta['loss'].item()),
|
271 |
-
('mel_loss', val_meta['mel_loss'].item()),
|
272 |
-
('frames/s', val_num_frames / val_meta['took']),
|
273 |
-
('took', val_meta['took'])]),
|
274 |
-
)
|
275 |
-
|
276 |
-
if was_training:
|
277 |
-
model.train()
|
278 |
-
return val_meta
|
279 |
-
|
280 |
-
|
281 |
-
def adjust_learning_rate(total_iter, opt, learning_rate, warmup_iters=None):
|
282 |
-
if warmup_iters == 0:
|
283 |
-
scale = 1.0
|
284 |
-
elif total_iter > warmup_iters:
|
285 |
-
scale = 1. / (total_iter ** 0.5)
|
286 |
-
else:
|
287 |
-
scale = total_iter / (warmup_iters ** 1.5)
|
288 |
-
|
289 |
-
for param_group in opt.param_groups:
|
290 |
-
param_group['lr'] = learning_rate * scale
|
291 |
-
|
292 |
-
|
293 |
-
def apply_ema_decay(model, ema_model, decay):
|
294 |
-
if not decay:
|
295 |
-
return
|
296 |
-
st = model.state_dict()
|
297 |
-
add_module = hasattr(model, 'module') and not hasattr(ema_model, 'module')
|
298 |
-
for k, v in ema_model.state_dict().items():
|
299 |
-
if add_module and not k.startswith('module.'):
|
300 |
-
k = 'module.' + k
|
301 |
-
v.copy_(decay * v + (1 - decay) * st[k])
|
302 |
-
|
303 |
-
|
304 |
-
def init_multi_tensor_ema(model, ema_model):
|
305 |
-
model_weights = list(model.state_dict().values())
|
306 |
-
ema_model_weights = list(ema_model.state_dict().values())
|
307 |
-
ema_overflow_buf = torch.cuda.IntTensor([0])
|
308 |
-
return model_weights, ema_model_weights, ema_overflow_buf
|
309 |
-
|
310 |
-
|
311 |
-
def apply_multi_tensor_ema(decay, model_weights, ema_weights, overflow_buf):
|
312 |
-
amp_C.multi_tensor_axpby(
|
313 |
-
65536, overflow_buf, [ema_weights, model_weights, ema_weights],
|
314 |
-
decay, 1-decay, -1)
|
315 |
-
|
316 |
-
|
-def main():
-    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
-                                     allow_abbrev=False)
-    parser = parse_args(parser)
-    args, _ = parser.parse_known_args()
-
-    if args.p_arpabet > 0.0:
-        cmudict.initialize(args.cmudict_path, args.heteronyms_path)
-
-    distributed_run = args.world_size > 1
-
-    torch.manual_seed(args.seed + args.local_rank)
-    np.random.seed(args.seed + args.local_rank)
-
-    if args.local_rank == 0:
-        if not os.path.exists(args.output):
-            os.makedirs(args.output)
-
-    log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
-    tb_subsets = ['train', 'val']
-    if args.ema_decay > 0.0:
-        tb_subsets.append('val_ema')
-
-    logger.init(log_fpath, args.output, enabled=(args.local_rank == 0),
-                tb_subsets=tb_subsets)
-    logger.parameters(vars(args), tb_subset='train')
-
-    parser = models.parse_model_args('FastPitch', parser)
-    args, unk_args = parser.parse_known_args()
-    if len(unk_args) > 0:
-        raise ValueError(f'Invalid options {unk_args}')
-
-    torch.backends.cudnn.benchmark = args.cudnn_benchmark
-
-    if distributed_run:
-        init_distributed(args, args.world_size, args.local_rank)
-    else:
-        if args.trainloader_repeats > 1:
-            print('WARNING: Disabled --trainloader-repeats, supported only for'
-                  ' multi-GPU data loading.')
-            args.trainloader_repeats = 1
-
-    device = torch.device('cuda' if args.cuda else 'cpu')
-    model_config = models.get_model_config('FastPitch', args)
-    model = models.get_model('FastPitch', model_config, device)
-
-    attention_kl_loss = AttentionBinarizationLoss()
-
-    # Store pitch mean/std as params to translate from Hz during inference
-    model.pitch_mean[0] = args.pitch_mean
-    model.pitch_std[0] = args.pitch_std
-
-    kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
-              weight_decay=args.weight_decay)
-    if args.optimizer == 'adam':
-        optimizer = FusedAdam(model.parameters(), **kw)
-        # optimizer = torch.optim.Adam(model.parameters(), **kw)
-    elif args.optimizer == 'lamb':
-        optimizer = FusedLAMB(model.parameters(), **kw)
-        # optimizer = torch.optim.Adam(model.parameters(), **kw)
-    else:
-        raise ValueError
-
-    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
-
-    if args.ema_decay > 0:
-        ema_model = copy.deepcopy(model)
-    else:
-        ema_model = None
-
-    if distributed_run:
-        model = DistributedDataParallel(
-            model, device_ids=[args.local_rank], output_device=args.local_rank,
-            find_unused_parameters=True)
-
-    train_state = {'epoch': 1, 'total_iter': 1}
-    checkpointer = Checkpointer(args.output, args.keep_milestones)
-
-    checkpointer.maybe_load(model, optimizer, scaler, train_state, args,
-                            ema_model)
-
-    start_epoch = train_state['epoch']
-    total_iter = train_state['total_iter']
-
-    criterion = FastPitchLoss(
-        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
-        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale,
-        attn_loss_scale=args.attn_loss_scale)
-
-    collate_fn = TTSCollate()
-
-    if args.local_rank == 0:
-        prepare_tmp(args.pitch_online_dir)
-
-    trainset = TTSDataset(audiopaths_and_text=args.training_files, **vars(args))
-    valset = TTSDataset(audiopaths_and_text=args.validation_files, **vars(args))
-
-    if distributed_run:
-        train_sampler = RepeatedDistributedSampler(args.trainloader_repeats,
-                                                   trainset, drop_last=True)
-        val_sampler = DistributedSampler(valset)
-        shuffle = False
-    else:
-        train_sampler, val_sampler, shuffle = None, None, False  # was True
-
-    # 4 workers are optimal on DGX-1 (from epoch 2 onwards)
-    kw = {'num_workers': args.num_workers, 'batch_size': args.batch_size,
-          'collate_fn': collate_fn}
-    train_loader = RepeatedDataLoader(args.trainloader_repeats, trainset,
-                                      shuffle=shuffle, drop_last=True,
-                                      sampler=train_sampler, pin_memory=True,
-                                      persistent_workers=True, **kw)
-    val_loader = DataLoader(valset, shuffle=False, sampler=val_sampler,
-                            pin_memory=False, **kw)
-    if args.ema_decay:
-        mt_ema_params = init_multi_tensor_ema(model, ema_model)
-
-    model.train()
-    bmark_stats = BenchmarkStats()
-
-    torch.cuda.synchronize()
-    for epoch in range(start_epoch, args.epochs + 1):
-        epoch_start_time = time.perf_counter()
-
-        epoch_loss = 0.0
-        epoch_mel_loss = 0.0
-        epoch_num_frames = 0
-        epoch_frames_per_sec = 0.0
-
-        if distributed_run:
-            train_loader.sampler.set_epoch(epoch)
-
-        iter_loss = 0
-        iter_num_frames = 0
-        iter_meta = {}
-        iter_start_time = time.perf_counter()
-
-        epoch_iter = 1
-        for batch, accum_step in zip(train_loader,
-                                     cycle(range(1, args.grad_accumulation + 1))):
-            if accum_step == 1:
-                adjust_learning_rate(total_iter, optimizer, args.learning_rate,
-                                     args.warmup_steps)
-
-                model.zero_grad(set_to_none=True)
-
-            x, y, num_frames = batch_to_gpu(batch)
-
-            with torch.cuda.amp.autocast(enabled=args.amp):
-                y_pred = model(x)
-                loss, meta = criterion(y_pred, y)
-
-                if (args.kl_loss_start_epoch is not None
-                        and epoch >= args.kl_loss_start_epoch):
-
-                    if args.kl_loss_start_epoch == epoch and epoch_iter == 1:
-                        print('Begin hard_attn loss')
-
-                    _, _, _, _, _, _, _, _, attn_soft, attn_hard, _, _ = y_pred
-                    binarization_loss = attention_kl_loss(attn_hard, attn_soft)
-                    kl_weight = min((epoch - args.kl_loss_start_epoch) / args.kl_loss_warmup_epochs, 1.0) * args.kl_loss_weight
-                    meta['kl_loss'] = binarization_loss.clone().detach() * kl_weight
-                    loss += kl_weight * binarization_loss
-
-                else:
-                    meta['kl_loss'] = torch.zeros_like(loss)
-                    kl_weight = 0
-                    binarization_loss = 0
-
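The attention binarization (KL) term above is phased in gradually: from kl_loss_start_epoch its weight rises linearly over kl_loss_warmup_epochs and is then clamped at kl_loss_weight. Illustrative values, assuming start epoch 100, a 100-epoch warmup, and weight 1.0 (assumed numbers, not this run's settings):

    # Hypothetical KL-weight ramp: min((epoch - start) / warmup, 1.0) * weight
    start, warmup, weight = 100, 100, 1.0
    for epoch in (100, 150, 200, 300):
        print(epoch, min((epoch - start) / warmup, 1.0) * weight)  # 0.0 0.5 1.0 1.0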
-            loss /= args.grad_accumulation
-            meta = {k: v / args.grad_accumulation
-                    for k, v in meta.items()}
-
-            if args.amp:
-                scaler.scale(loss).backward()
-            else:
-                loss.backward()
-
-            if distributed_run:
-                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
-                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
-                meta = {k: reduce_tensor(v, args.world_size) for k, v in meta.items()}
-            else:
-                reduced_loss = loss.item()
-                reduced_num_frames = num_frames.item()
-            if np.isnan(reduced_loss):
-                raise Exception("loss is NaN")
-
-            iter_loss += reduced_loss
-            iter_num_frames += reduced_num_frames
-            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}
-            if accum_step % args.grad_accumulation == 0:
-
-                logger.log_grads_tb(total_iter, model)
-                if args.amp:
-                    scaler.unscale_(optimizer)
-                    torch.nn.utils.clip_grad_norm_(
-                        model.parameters(), args.grad_clip_thresh)
-                    scaler.step(optimizer)
-                    scaler.update()
-                else:
-                    torch.nn.utils.clip_grad_norm_(
-                        model.parameters(), args.grad_clip_thresh)
-                    optimizer.step()
-
-                if args.ema_decay > 0.0:
-                    apply_multi_tensor_ema(args.ema_decay, *mt_ema_params)
-
-                iter_mel_loss = iter_meta['mel_loss'].item()
-                iter_kl_loss = iter_meta['kl_loss'].item()
-                iter_time = time.perf_counter() - iter_start_time
-                epoch_frames_per_sec += iter_num_frames / iter_time
-                epoch_loss += iter_loss
-                epoch_num_frames += iter_num_frames
-                epoch_mel_loss += iter_mel_loss
-
-                num_iters = len(train_loader) // args.grad_accumulation
-                log((epoch, epoch_iter, num_iters), tb_total_steps=total_iter,
-                    subset='train', data=OrderedDict([
-                        ('loss', iter_loss),
-                        ('mel_loss', iter_mel_loss),
-                        ('kl_loss', iter_kl_loss),
-                        ('kl_weight', kl_weight),
-                        ('frames/s', iter_num_frames / iter_time),
-                        ('took', iter_time),
-                        ('lrate', optimizer.param_groups[0]['lr'])]),
-                    )
-
-                iter_loss = 0
-                iter_num_frames = 0
-                iter_meta = {}
-                iter_start_time = time.perf_counter()
-
-                if epoch_iter == num_iters:
-                    break
-                epoch_iter += 1
-                total_iter += 1
-
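The loop above steps the optimizer only every args.grad_accumulation micro-batches: the loss is pre-divided by the accumulation factor, gradients sum across micro-batches, and clipping/stepping happen when accum_step wraps around. A runnable stripped-down sketch of that pattern (toy model and data, not this script's):

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    accum = 4
    model.zero_grad(set_to_none=True)
    for step in range(1, 9):                   # 8 micro-batches -> 2 optimizer steps
        x = torch.randn(2, 4)
        loss = model(x).pow(2).mean() / accum  # scale so summed grads average out
        loss.backward()
        if step % accum == 0:
            optimizer.step()
            model.zero_grad(set_to_none=True)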
-        # Finished epoch
-        epoch_loss /= epoch_iter
-        epoch_mel_loss /= epoch_iter
-        epoch_time = time.perf_counter() - epoch_start_time
-        log((epoch,), tb_total_steps=None, subset='train_avg',
-            data=OrderedDict([
-                ('loss', epoch_loss),
-                ('mel_loss', epoch_mel_loss),
-                ('frames/s', epoch_num_frames / epoch_time),
-                ('took', epoch_time)]),
-            )
-        bmark_stats.update(epoch_num_frames, epoch_loss, epoch_mel_loss,
-                           epoch_time)
-
-        if epoch % args.validation_freq == 0:
-            validate(model, epoch, total_iter, criterion, val_loader,
-                     distributed_run, batch_to_gpu, ema=False, local_rank=args.local_rank)
-
if args.ema_decay > 0:
|
575 |
-
validate(ema_model, epoch, total_iter, criterion, val_loader,
|
576 |
-
distributed_run, batch_to_gpu, args.local_rank, ema=True)
|
577 |
-
|
578 |
-
# save before making sched.step() for proper loading of LR
|
579 |
-
checkpointer.maybe_save(args, model, ema_model, optimizer, scaler,
|
580 |
-
epoch, total_iter, model_config)
|
581 |
-
logger.flush()
|
582 |
-
|
583 |
-
# Finished training
|
584 |
-
if len(bmark_stats) > 0:
|
585 |
-
log((), tb_total_steps=None, subset='train_avg',
|
586 |
-
data=bmark_stats.get(args.benchmark_epochs_num))
|
587 |
-
|
588 |
-
validate(model, None, total_iter, criterion, val_loader, distributed_run,
|
589 |
-
batch_to_gpu)
|
590 |
-
|
591 |
-
|
592 |
-
if __name__ == '__main__':
|
593 |
-
main()
|