Spaces: Running

katrihiovain committed · Commit 95b5cf1 · 1 Parent(s): 2b63853

removed unnecessary files and updated app_py
Browse files
- README1.md +0 -15
- app.py +1 -1
- common/text/symbols_2.py +0 -64
- common/text/symbols_BACKUP_MARCH_2024.py +0 -64
- common/text/symbols_ORIGINAL.py +0 -65
- common/text/symbols_backup.py +0 -54
- common/text/symbols_sme.py +0 -64
- common/text/symbols_sme_1.py +0 -64
- common/text/symbols_smj.py +0 -48
- common/utils_hfg.py +0 -14
- common/utils_ok.py +0 -291
- fastpitch/data_function (copy).py.txt +0 -425
- fastpitch/data_function_model_py.zip +0 -3
- fastpitch/utils_trainplot_transformers.zip +0 -3
- fastpitch/utils_trainplot_transformers/train_1_with_plot.py +0 -591
- fastpitch/utils_trainplot_transformers/transformer.py +0 -213
- fastpitch/utils_trainplot_transformers/transformer_jit.py +0 -255
- fastpitch/utils_trainplot_transformers/utils.py +0 -291
- gradio_gui.py +0 -74
- gradio_gui_katri.py +0 -73
- prepare_dataset.py +0 -180
- run_training_cluster_s.sh +0 -33
- scripts/docker/build.sh +0 -3
- scripts/docker/interactive.sh +0 -5
- scripts/download_cmudict.sh +0 -10
- scripts/download_dataset.sh +0 -17
- scripts/download_models.sh +0 -63
- scripts/inference_benchmark.sh +0 -16
- scripts/inference_example.sh +0 -78
- scripts/prepare_dataset.sh +0 -19
- scripts/train.sh +0 -100
- scripts/train_multilang.sh +0 -110
- train_1_with_plot_multilang.py +0 -593
README1.md
DELETED
@@ -1,15 +0,0 @@
-# FastPitchMulti
-Experimental multi-lingual FastPitch
-
-What's done:
-- [x] Conditioning on language and speaker labels
-- [x] Dataset and preprocessing of Sámi data
-- [x] Combined character set for the Sámi languages
-- [x] Train a model on Sámi languages
-- [x] Selecting Estonian data
-- [x] Processing Estonian data
-- [ ] Train a model on Sámi x 3, Finnish, Estonian
-
-Ideas:
-- Move the language embedding to the very beginning of the encoder
-
app.py
CHANGED
@@ -11,7 +11,7 @@ speakers={"aj0": 0,
 "aj1": 1,
 "am": 2,
 "bi": 3,
-"kd": 4,
+#"kd": 4,
 "ln": 5,
 "lo": 6,
 "ms": 7,
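Note on the app.py change, with a minimal sketch of its effect below (the speakers dict is quoted from the diff; the point about embedding-table layout is an assumption about how the app consumes these IDs). Commenting out "kd" hides that voice without renumbering the rest, so a speaker-embedding table trained with the original IDs still lines up; index 4 simply goes unused.

    # Sketch: the mapping after this commit. "kd" can no longer be selected,
    # but the surviving speakers keep their trained IDs.
    speakers = {"aj0": 0, "aj1": 1, "am": 2, "bi": 3,
                # "kd": 4,
                "ln": 5, "lo": 6, "ms": 7}

    assert 4 not in speakers.values()    # removed voice is unreachable
    assert max(speakers.values()) == 7   # remaining indices unchanged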
common/text/symbols_2.py
DELETED
@@ -1,64 +0,0 @@
-""" from https://github.com/keithito/tacotron """
-
-'''
-Defines the set of symbols used in text input to the model.
-
-The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
-from .cmudict import valid_symbols
-
-
-# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-_arpabet = ['@' + s for s in valid_symbols]
-
-
-def get_symbols(symbol_set='english_basic'):
-    if symbol_set == 'english_basic':
-        _pad = '_'
-        _punctuation = '!\'(),.:;? '
-        _special = '-'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_basic_lowercase':
-        _pad = '_'
-        _punctuation = '!\'"(),.:;? '
-        _special = '-'
-        _letters = 'abcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_expanded':
-        _punctuation = '!\'",.:;? '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëñöøćž'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
-    elif symbol_set == 'smj_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        # _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
-        _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    elif symbol_set == 'sme_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëńñöøćčžđšŧ' #also north sámi letters...
-        # _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋoøöpqrstuvwxyz'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    else:
-        raise Exception("{} symbol set does not exist".format(symbol_set))
-
-    return symbols
-
-
-def get_pad_idx(symbol_set='english_basic'):
-    if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded', 'sme_expanded'}:
-        return 0
-    else:
-        raise Exception("{} symbol set not used yet".format(symbol_set))
common/text/symbols_BACKUP_MARCH_2024.py
DELETED
@@ -1,64 +0,0 @@
-""" from https://github.com/keithito/tacotron """
-
-'''
-Defines the set of symbols used in text input to the model.
-
-The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
-from .cmudict import valid_symbols
-
-
-# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-_arpabet = ['@' + s for s in valid_symbols]
-
-
-def get_symbols(symbol_set='english_basic'):
-    if symbol_set == 'english_basic':
-        _pad = '_'
-        _punctuation = '!\'(),.:;? '
-        _special = '-'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_basic_lowercase':
-        _pad = '_'
-        _punctuation = '!\'"(),.:;? '
-        _special = '-'
-        _letters = 'abcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_expanded':
-        _punctuation = '!\'",.:;? '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëñöøćž'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
-    elif symbol_set == 'smj_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        # _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
-        _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTŦUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz' ########################## Ŧ ########################
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    elif symbol_set == 'sme_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëńñöøćčžđšŧ' #also north sámi letters...
-        # _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCČDĐEFGHIJKLMNŊOØÖPQRSŠTŦUVWXYZŽaáæåäbcčdđefghijklmnŋoøöpqrsštŧuvwxyzž'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    else:
-        raise Exception("{} symbol set does not exist".format(symbol_set))
-
-    return symbols
-
-
-def get_pad_idx(symbol_set='english_basic'):
-    if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded', 'sme_expanded'}:
-        return 0
-    else:
-        raise Exception("{} symbol set not used yet".format(symbol_set))
common/text/symbols_ORIGINAL.py
DELETED
@@ -1,65 +0,0 @@
-""" from https://github.com/keithito/tacotron """
-
-'''
-Defines the set of symbols used in text input to the model.
-
-The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
-from .cmudict import valid_symbols
-
-
-# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-_arpabet = ['@' + s for s in valid_symbols]
-
-
-def get_symbols(symbol_set='english_basic'):
-    if symbol_set == 'english_basic':
-        _pad = '_'
-        _punctuation = '!\'(),.:;? '
-        _special = '-'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_basic_lowercase':
-        _pad = '_'
-        _punctuation = '!\'"(),.:;? '
-        _special = '-'
-        _letters = 'abcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_expanded':
-        _punctuation = '!\'",.:;? '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëñöøćž'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
-    elif symbol_set == 'smj_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        # _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
-        _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTŦUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    elif symbol_set == 'sme_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëńñöøćčžđšŧ' #also north sámi letters...
-        # _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCČDĐEFGHIJKLMNŊOØÖPQRSŠTŦUVWXYZŽaáæåäbcčdđefghijklmnŋoøöpqrsštŧuvwxyzž'
-        # _letters = 'AÁÆÅÄBCDĐEFGHIJKLMNŊOØÖPQRSŠTŦUVWXYZŽaáæåäbcčdđefghijklmnŋoøöpqrsštŧuvwxyzž'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    else:
-        raise Exception("{} symbol set does not exist".format(symbol_set))
-
-    return symbols
-
-
-def get_pad_idx(symbol_set='english_basic'):
-    if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded', 'sme_expanded'}:
-        return 0
-    else:
-        raise Exception("{} symbol set not used yet".format(symbol_set))
common/text/symbols_backup.py
DELETED
@@ -1,54 +0,0 @@
-""" from https://github.com/keithito/tacotron """
-
-'''
-Defines the set of symbols used in text input to the model.
-
-The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
-from .cmudict import valid_symbols
-
-
-# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-_arpabet = ['@' + s for s in valid_symbols]
-
-
-def get_symbols(symbol_set='english_basic'):
-    if symbol_set == 'english_basic':
-        _pad = '_'
-        _punctuation = '!\'(),.:;? '
-        _special = '-'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_basic_lowercase':
-        _pad = '_'
-        _punctuation = '!\'"(),.:;? '
-        _special = '-'
-        _letters = 'abcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_expanded':
-        _punctuation = '!\'",.:;? '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëñöøćž'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
-    elif symbol_set == 'smj_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        # _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
-        _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    else:
-        raise Exception("{} symbol set does not exist".format(symbol_set))
-
-    return symbols
-
-
-def get_pad_idx(symbol_set='english_basic'):
-    if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded'}:
-        return 0
-    else:
-        raise Exception("{} symbol set not used yet".format(symbol_set))
common/text/symbols_sme.py
DELETED
@@ -1,64 +0,0 @@
-""" from https://github.com/keithito/tacotron """
-
-'''
-Defines the set of symbols used in text input to the model.
-
-The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
-from .cmudict import valid_symbols
-
-
-# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-_arpabet = ['@' + s for s in valid_symbols]
-
-
-def get_symbols(symbol_set='english_basic'):
-    if symbol_set == 'english_basic':
-        _pad = '_'
-        _punctuation = '!\'(),.:;? '
-        _special = '-'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_basic_lowercase':
-        _pad = '_'
-        _punctuation = '!\'"(),.:;? '
-        _special = '-'
-        _letters = 'abcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_expanded':
-        _punctuation = '!\'",.:;? '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëñöøćž'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
-    elif symbol_set == 'smj_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        # _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
-        _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    elif symbol_set == 'sme_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëńñöøćčžđšŧ' #also north sámi letters...
-        # _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋoøöpqrstuvwxyz'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    else:
-        raise Exception("{} symbol set does not exist".format(symbol_set))
-
-    return symbols
-
-
-def get_pad_idx(symbol_set='english_basic'):
-    if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded', 'sme_expanded'}:
-        return 0
-    else:
-        raise Exception("{} symbol set not used yet".format(symbol_set))
common/text/symbols_sme_1.py
DELETED
@@ -1,64 +0,0 @@
-""" from https://github.com/keithito/tacotron """
-
-'''
-Defines the set of symbols used in text input to the model.
-
-The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
-from .cmudict import valid_symbols
-
-
-# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-_arpabet = ['@' + s for s in valid_symbols]
-
-
-def get_symbols(symbol_set='english_basic'):
-    if symbol_set == 'english_basic':
-        _pad = '_'
-        _punctuation = '!\'(),.:;? '
-        _special = '-'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_basic_lowercase':
-        _pad = '_'
-        _punctuation = '!\'"(),.:;? '
-        _special = '-'
-        _letters = 'abcdefghijklmnopqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'english_expanded':
-        _punctuation = '!\'",.:;? '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëñöøćž'
-        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
-    elif symbol_set == 'smj_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        # _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
-        _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    elif symbol_set == 'sme_expanded':
-        _punctuation = '!\'",.:;?- '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëńñöøćčžđšŧ' #also north sámi letters...
-        # _accented = 'áçéêëñöø' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDĐEFGHIJKLMNŊOØÖPQRSŠTUVWXYZŽaáæåäbcdđefghijklmnŋoøöpqrsštŧuvwxyzž'
-        # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
-        symbols = list(_punctuation + _letters) + _arpabet
-    else:
-        raise Exception("{} symbol set does not exist".format(symbol_set))
-
-    return symbols
-
-
-def get_pad_idx(symbol_set='english_basic'):
-    if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded', 'sme_expanded'}:
-        return 0
-    else:
-        raise Exception("{} symbol set not used yet".format(symbol_set))
common/text/symbols_smj.py
DELETED
@@ -1,48 +0,0 @@
-""" from https://github.com/keithito/tacotron """
-
-'''
-Defines the set of symbols used in text input to the model.
-
-The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
-from .cmudict import valid_symbols
-
-
-# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-_arpabet = ['@' + s for s in valid_symbols]
-
-
-def get_symbols(symbol_set='smj_basic'):
-    if symbol_set == 'smj_basic':
-        _pad = '_'
-        _punctuation = '!\'(),.:;? '
-        _special = '-'
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
-        # LETTERS = 'AaÁáBbCcDdEeFfGgHhIiJjKkLlMmNnŊŋOoPpRrSsTtUuVvZzŃńÑñÆæØøÅåÄäÖö'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'smj_basic_lowercase':
-        _pad = '_'
-        _punctuation = '!\'"(),.:;? '
-        _special = '-'
-        # _letters = 'abcdefghijklmnopqrstuvwxyz'
-        _letters = 'aáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == 'smj_expanded':
-        _punctuation = '!\'",.:;? '
-        _math = '#%&*+-/[]()'
-        _special = '_@©°½—₩€$'
-        _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
-        # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
-        _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
-        symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
-    else:
-        raise Exception("{} symbol set does not exist".format(symbol_set))
-
-    return symbols
-
-
-def get_pad_idx(symbol_set='smj_basic'):
-    if symbol_set in {'smj_basic', 'smj_basic_lowercase'}:
-        return 0
-    else:
-        raise Exception("{} symbol set not used yet".format(symbol_set))
common/utils_hfg.py
DELETED
@@ -1,14 +0,0 @@
-
-##############################################################################
-# Foreing utils.py from HiFi-GAN
-##############################################################################
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size*dilation - dilation)/2)
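The get_padding helper above encodes a small identity worth spelling out: (kernel_size*dilation - dilation)/2 equals (kernel_size - 1)*dilation/2, i.e. half the receptive-field growth of a dilated convolution, which is exactly the "same" padding that keeps output length equal to input length for stride-1 convolutions with odd kernels. A quick check (plain Python, no assumptions beyond the function itself):

    def get_padding(kernel_size, dilation=1):
        # Half the receptive-field growth: (k - 1) * d / 2.
        return int((kernel_size * dilation - dilation) / 2)

    assert get_padding(3) == 1              # plain 3-tap conv: pad 1 per side
    assert get_padding(3, dilation=2) == 2
    assert get_padding(7, dilation=4) == 12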
common/utils_ok.py
DELETED
@@ -1,291 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# MIT License
-#
-# Copyright (c) 2020 Jungil Kong
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-# The following functions/classes were based on code from https://github.com/jik876/hifi-gan:
-# init_weights, get_padding, AttrDict
-
-import ctypes
-import glob
-import os
-import re
-import shutil
-import warnings
-from collections import defaultdict, OrderedDict
-from pathlib import Path
-from typing import Optional
-
-import librosa
-import numpy as np
-
-import torch
-import torch.distributed as dist
-from scipy.io.wavfile import read
-
-
-def mask_from_lens(lens, max_len: Optional[int] = None):
-    if max_len is None:
-        max_len = lens.max()
-    ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
-    mask = torch.lt(ids, lens.unsqueeze(1))
-    return mask
-
-
-def load_wav(full_path, torch_tensor=False):
-    import soundfile  # flac
-    data, sampling_rate = soundfile.read(full_path, dtype='int16')
-    if torch_tensor:
-        return torch.FloatTensor(data.astype(np.float32)), sampling_rate
-    else:
-        return data, sampling_rate
-
-
-def load_wav_to_torch(full_path, force_sampling_rate=None):
-    if force_sampling_rate is not None:
-        data, sampling_rate = librosa.load(full_path, sr=force_sampling_rate)
-    else:
-        sampling_rate, data = read(full_path)
-
-    return torch.FloatTensor(data.astype(np.float32)), sampling_rate
-
-
-def load_filepaths_and_text(dataset_path, fnames, has_speakers=False, split="|"):
-    def split_line(root, line):
-        parts = line.strip().split(split)
-        if has_speakers:
-            paths, non_paths = parts[:-2], parts[-2:]
-        else:
-            paths, non_paths = parts[:-1], parts[-1:]
-        return tuple(str(Path(root, p)) for p in paths) + tuple(non_paths)
-
-    fpaths_and_text = []
-    for fname in fnames:
-        with open(fname, encoding='utf-8') as f:
-            fpaths_and_text += [split_line(dataset_path, line) for line in f]
-    return fpaths_and_text
-
-
-def to_gpu(x):
-    x = x.contiguous()
-    return x.cuda(non_blocking=True) if torch.cuda.is_available() else x
-
-
-def l2_promote():
-    _libcudart = ctypes.CDLL('libcudart.so')
-    # Set device limit on the current device
-    # cudaLimitMaxL2FetchGranularity = 0x05
-    pValue = ctypes.cast((ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int))
-    _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
-    _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
-    assert pValue.contents.value == 128
-
-
-def prepare_tmp(path):
-    if path is None:
-        return
-    p = Path(path)
-    if p.is_dir():
-        warnings.warn(f'{p} exists. Removing...')
-        shutil.rmtree(p, ignore_errors=True)
-    p.mkdir(parents=False, exist_ok=False)
-
-
-def print_once(*msg):
-    if not dist.is_initialized() or dist.get_rank() == 0:
-        print(*msg)
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size*dilation - dilation)/2)
-
-
-class AttrDict(dict):
-    def __init__(self, *args, **kwargs):
-        super(AttrDict, self).__init__(*args, **kwargs)
-        self.__dict__ = self
-
-
-class DefaultAttrDict(defaultdict):
-    def __init__(self, *args, **kwargs):
-        super(DefaultAttrDict, self).__init__(*args, **kwargs)
-        self.__dict__ = self
-
-    def __getattr__(self, item):
-        return self[item]
-
-
-class BenchmarkStats:
-    """ Tracks statistics used for benchmarking. """
-    def __init__(self):
-        self.num_frames = []
-        self.losses = []
-        self.mel_losses = []
-        self.took = []
-
-    def update(self, num_frames, losses, mel_losses, took):
-        self.num_frames.append(num_frames)
-        self.losses.append(losses)
-        self.mel_losses.append(mel_losses)
-        self.took.append(took)
-
-    def get(self, n_epochs):
-        frames_s = sum(self.num_frames[-n_epochs:]) / sum(self.took[-n_epochs:])
-        return {'frames/s': frames_s,
-                'loss': np.mean(self.losses[-n_epochs:]),
-                'mel_loss': np.mean(self.mel_losses[-n_epochs:]),
-                'took': np.mean(self.took[-n_epochs:]),
-                'benchmark_epochs_num': n_epochs}
-
-    def __len__(self):
-        return len(self.losses)
-
-
-class Checkpointer:
-
-    def __init__(self, save_dir, keep_milestones=[]):
-        self.save_dir = save_dir
-        self.keep_milestones = keep_milestones
-
-        find = lambda name: [
-            (int(re.search("_(\d+).pt", fn).group(1)), fn)
-            for fn in glob.glob(f"{save_dir}/{name}_checkpoint_*.pt")]
-
-        tracked = sorted(find("FastPitch"), key=lambda t: t[0])
-        self.tracked = OrderedDict(tracked)
-
-    def last_checkpoint(self, output):
-
-        def corrupted(fpath):
-            try:
-                torch.load(fpath, map_location="cpu")
-                return False
-            except:
-                warnings.warn(f"Cannot load {fpath}")
-                return True
-
-        saved = sorted(
-            glob.glob(f"{output}/FastPitch_checkpoint_*.pt"),
-            key=lambda f: int(re.search("_(\d+).pt", f).group(1)))
-
-        if len(saved) >= 1 and not corrupted(saved[-1]):
-            return saved[-1]
-        elif len(saved) >= 2:
-            return saved[-2]
-        else:
-            return None
-
-    def maybe_load(self, model, optimizer, scaler, train_state, args,
-                   ema_model=None):
-
-        assert args.checkpoint_path is None or args.resume is False, (
-            "Specify a single checkpoint source")
-
-        fpath = None
-        if args.checkpoint_path is not None:
-            fpath = args.checkpoint_path
-            self.tracked = OrderedDict()  # Do not track/delete prev ckpts
-        elif args.resume:
-            fpath = self.last_checkpoint(args.output)
-
-        if fpath is None:
-            return
-
-        print_once(f"Loading model and optimizer state from {fpath}")
-        ckpt = torch.load(fpath, map_location="cpu")
-        train_state["epoch"] = ckpt["epoch"] + 1
-        train_state["total_iter"] = ckpt["iteration"]
-
-        no_pref = lambda sd: {re.sub("^module.", "", k): v for k, v in sd.items()}
-        unwrap = lambda m: getattr(m, "module", m)
-
-        unwrap(model).load_state_dict(no_pref(ckpt["state_dict"]))
-
-        if ema_model is not None:
-            unwrap(ema_model).load_state_dict(no_pref(ckpt["ema_state_dict"]))
-
-        optimizer.load_state_dict(ckpt["optimizer"])
-
-        if "scaler" in ckpt:
-            scaler.load_state_dict(ckpt["scaler"])
-        else:
-            warnings.warn("AMP scaler state missing from the checkpoint.")
-
-    def maybe_save(self, args, model, ema_model, optimizer, scaler, epoch,
-                   total_iter, config):
-
-        intermediate = (args.epochs_per_checkpoint > 0
-                        and epoch % args.epochs_per_checkpoint == 0)
-        final = epoch == args.epochs
-
-        if not intermediate and not final and epoch not in self.keep_milestones:
-            return
-
-        rank = 0
-        if dist.is_initialized():
-            dist.barrier()
-            rank = dist.get_rank()
-
-        if rank != 0:
-            return
-
-        unwrap = lambda m: getattr(m, "module", m)
-        ckpt = {"epoch": epoch,
-                "iteration": total_iter,
-                "config": config,
-                "train_setup": args.__dict__,
-                "state_dict": unwrap(model).state_dict(),
-                "optimizer": optimizer.state_dict(),
-                "scaler": scaler.state_dict()}
-        if ema_model is not None:
-            ckpt["ema_state_dict"] = unwrap(ema_model).state_dict()
-
-        fpath = Path(args.output, f"FastPitch_checkpoint_{epoch}.pt")
-        print(f"Saving model and optimizer state at epoch {epoch} to {fpath}")
-        torch.save(ckpt, fpath)
-
-        # Remove old checkpoints; keep milestones and the last two
-        self.tracked[epoch] = fpath
-        for epoch in set(list(self.tracked)[:-2]) - set(self.keep_milestones):
-            try:
-                os.remove(self.tracked[epoch])
-            except:
-                pass
-            del self.tracked[epoch]
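Of the helpers deleted with utils_ok.py, mask_from_lens is the most broadly useful: it turns a batch of sequence lengths into a boolean padding mask by broadcasting a position ramp against the lengths. A minimal usage sketch (PyTorch only; values chosen for illustration):

    import torch

    def mask_from_lens(lens, max_len=None):
        # Positions [0, max_len) compared against each length:
        # True inside the sequence, False in the padded tail.
        if max_len is None:
            max_len = lens.max()
        ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
        return torch.lt(ids, lens.unsqueeze(1))

    print(mask_from_lens(torch.tensor([3, 1, 4])))
    # tensor([[ True,  True,  True, False],
    #         [ True, False, False, False],
    #         [ True,  True,  True,  True]])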
fastpitch/data_function (copy).py.txt
DELETED
@@ -1,425 +0,0 @@
-# *****************************************************************************
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-#     * Redistributions of source code must retain the above copyright
-#       notice, this list of conditions and the following disclaimer.
-#     * Redistributions in binary form must reproduce the above copyright
-#       notice, this list of conditions and the following disclaimer in the
-#       documentation and/or other materials provided with the distribution.
-#     * Neither the name of the NVIDIA CORPORATION nor the
-#       names of its contributors may be used to endorse or promote products
-#       derived from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
-# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#
-# *****************************************************************************
-
-import functools
-import json
-import re
-from pathlib import Path
-
-import librosa
-import numpy as np
-import torch
-import torch.nn.functional as F
-from scipy import ndimage
-from scipy.stats import betabinom
-
-import common.layers as layers
-from common.text.text_processing import TextProcessing
-from common.utils import load_wav_to_torch, load_filepaths_and_text, to_gpu
-
-
-class BetaBinomialInterpolator:
-    """Interpolates alignment prior matrices to save computation.
-
-    Calculating beta-binomial priors is costly. Instead cache popular sizes
-    and use img interpolation to get priors faster.
-    """
-    def __init__(self, round_mel_len_to=100, round_text_len_to=20):
-        self.round_mel_len_to = round_mel_len_to
-        self.round_text_len_to = round_text_len_to
-        self.bank = functools.lru_cache(beta_binomial_prior_distribution)
-
-    def round(self, val, to):
-        return max(1, int(np.round((val + 1) / to))) * to
-
-    def __call__(self, w, h):
-        bw = self.round(w, to=self.round_mel_len_to)
-        bh = self.round(h, to=self.round_text_len_to)
-        ret = ndimage.zoom(self.bank(bw, bh).T, zoom=(w / bw, h / bh), order=1)
-        assert ret.shape[0] == w, ret.shape
-        assert ret.shape[1] == h, ret.shape
-        return ret
-
-
-def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling=1.0):
-    P = phoneme_count
-    M = mel_count
-    x = np.arange(0, P)
-    mel_text_probs = []
-    for i in range(1, M+1):
-        a, b = scaling * i, scaling * (M + 1 - i)
-        rv = betabinom(P, a, b)
-        mel_i_prob = rv.pmf(x)
-        mel_text_probs.append(mel_i_prob)
-    return torch.tensor(np.array(mel_text_probs))
-
-
-def estimate_pitch(wav, mel_len, method='pyin', normalize_mean=None,
-                   normalize_std=None, n_formants=1):
-
-    if type(normalize_mean) is float or type(normalize_mean) is list:
-        normalize_mean = torch.tensor(normalize_mean)
-
-    if type(normalize_std) is float or type(normalize_std) is list:
-        normalize_std = torch.tensor(normalize_std)
-
-    if method == 'pyin':
-
-        snd, sr = librosa.load(wav)
-        pitch_mel, voiced_flag, voiced_probs = librosa.pyin(
-            snd, fmin=librosa.note_to_hz('C2'),
-            # fmax=librosa.note_to_hz('C7'), frame_length=1024)
-            fmax=400, frame_length=1024)
-        assert np.abs(mel_len - pitch_mel.shape[0]) <= 1.0
-
-        pitch_mel = np.where(np.isnan(pitch_mel), 0.0, pitch_mel)
-        pitch_mel = torch.from_numpy(pitch_mel).unsqueeze(0)
-        pitch_mel = F.pad(pitch_mel, (0, mel_len - pitch_mel.size(1)))
-
-        if n_formants > 1:
-            raise NotImplementedError
-
-    else:
-        raise ValueError
-
-    pitch_mel = pitch_mel.float()
-
-    if normalize_mean is not None:
-        assert normalize_std is not None
-        pitch_mel = normalize_pitch(pitch_mel, normalize_mean, normalize_std)
-
-    return pitch_mel
-
-
-def normalize_pitch(pitch, mean, std):
-    zeros = (pitch == 0.0)
-    pitch -= mean[:, None]
-    pitch /= std[:, None]
-    pitch[zeros] = 0.0
-    return pitch
-
-
-class TTSDataset(torch.utils.data.Dataset):
-    """
-        1) loads audio,text pairs
-        2) normalizes text and converts them to sequences of one-hot vectors
-        3) computes mel-spectrograms from audio files.
-    """
-    def __init__(self,
-                 dataset_path,
-                 audiopaths_and_text,
-                 text_cleaners,
-                 n_mel_channels,
-                 symbol_set='english_basic',
-                 p_arpabet=1.0,
-                 n_speakers=1,
-                 load_mel_from_disk=True,
-                 load_pitch_from_disk=True,
-                 pitch_mean=214.72203,  # LJSpeech defaults
-                 pitch_std=65.72038,
-                 max_wav_value=None,
-                 sampling_rate=None,
-                 filter_length=None,
-                 hop_length=None,
-                 win_length=None,
-                 mel_fmin=None,
-                 mel_fmax=None,
-                 prepend_space_to_text=False,
-                 append_space_to_text=False,
-                 pitch_online_dir=None,
-                 betabinomial_online_dir=None,
-                 use_betabinomial_interpolator=True,
-                 pitch_online_method='pyin',
-                 **ignored):
-
-        # Expect a list of filenames
-        if type(audiopaths_and_text) is str:
-            audiopaths_and_text = [audiopaths_and_text]
-
-        self.dataset_path = dataset_path
-        self.audiopaths_and_text = load_filepaths_and_text(
-            dataset_path, audiopaths_and_text,
-            has_speakers=(n_speakers > 1))
-        self.load_mel_from_disk = load_mel_from_disk
-        if not load_mel_from_disk:
-            self.max_wav_value = max_wav_value
-            self.sampling_rate = sampling_rate
-            self.stft = layers.TacotronSTFT(
-                filter_length, hop_length, win_length,
-                n_mel_channels, sampling_rate, mel_fmin, mel_fmax)
-        self.load_pitch_from_disk = load_pitch_from_disk
-
-        self.prepend_space_to_text = prepend_space_to_text
-        self.append_space_to_text = append_space_to_text
-
-        assert p_arpabet == 0.0 or p_arpabet == 1.0, (
-            'Only 0.0 and 1.0 p_arpabet is currently supported. '
-            'Variable probability breaks caching of betabinomial matrices.')
-
-        self.tp = TextProcessing(symbol_set, text_cleaners, p_arpabet=p_arpabet)
-        self.n_speakers = n_speakers
-        self.pitch_tmp_dir = pitch_online_dir
-        self.f0_method = pitch_online_method
-        self.betabinomial_tmp_dir = betabinomial_online_dir
-        self.use_betabinomial_interpolator = use_betabinomial_interpolator
-
-        if use_betabinomial_interpolator:
-            self.betabinomial_interpolator = BetaBinomialInterpolator()
-
-        expected_columns = (2 + int(load_pitch_from_disk) + (n_speakers > 1))
-
-        assert not (load_pitch_from_disk and self.pitch_tmp_dir is not None)
-
-        if len(self.audiopaths_and_text[0]) < expected_columns:
-            raise ValueError(f'Expected {expected_columns} columns in audiopaths file. '
-                             'The format is <mel_or_wav>|[<pitch>|]<text>[|<speaker_id>]')
-
-        if len(self.audiopaths_and_text[0]) > expected_columns:
-            print('WARNING: Audiopaths file has more columns than expected')
-
-        to_tensor = lambda x: torch.Tensor([x]) if type(x) is float else x
-        self.pitch_mean = to_tensor(pitch_mean)
-        self.pitch_std = to_tensor(pitch_std)
-
-    def __getitem__(self, index):
-        # Separate filename and text
-        if self.n_speakers > 1:
-            audiopath, *extra, text, speaker = self.audiopaths_and_text[index]
-            speaker = int(speaker)
-        else:
-            audiopath, *extra, text = self.audiopaths_and_text[index]
-            speaker = None
-
-        mel = self.get_mel(audiopath)
-        text = self.get_text(text)
-        # print(text)
-        pitch = self.get_pitch(index, mel.size(-1))
-        energy = torch.norm(mel.float(), dim=0, p=2)
-        attn_prior = self.get_prior(index, mel.shape[1], text.shape[0])
-
-        assert pitch.size(-1) == mel.size(-1)
-
-        # No higher formants?
-        if len(pitch.size()) == 1:
-            pitch = pitch[None, :]
-
-
-        return (text, mel, len(text), pitch, energy, speaker, attn_prior,
-                audiopath)
-
-    def __len__(self):
-        return len(self.audiopaths_and_text)
-
-    def get_mel(self, filename):
-        if not self.load_mel_from_disk:
-            audio, sampling_rate = load_wav_to_torch(filename)
-            if sampling_rate != self.stft.sampling_rate:
-                raise ValueError("{} SR doesn't match target {} SR".format(
-                    sampling_rate, self.stft.sampling_rate))
-            audio_norm = audio / self.max_wav_value
-            audio_norm = audio_norm.unsqueeze(0)
-            audio_norm = torch.autograd.Variable(audio_norm,
-                                                 requires_grad=False)
-            melspec = self.stft.mel_spectrogram(audio_norm)
-            melspec = torch.squeeze(melspec, 0)
-        else:
-            melspec = torch.load(filename)
-            assert melspec.size(0) == self.stft.n_mel_channels, (
-                'Mel dimension mismatch: given {}, expected {}'.format(
-                    melspec.size(0), self.stft.n_mel_channels))
-
-        ################ Plotting mels ########################################
-        import matplotlib.pyplot as plt
-        # plt.imshow(melspec.detach().cpu().T,aspect="auto")
-        fig, ax1 = plt.subplots(ncols=1)
-        pos = ax1.imshow(melspec.cpu().numpy().T,aspect="auto")
-        fig.colorbar(pos, ax=ax1)
-        plt.show()
-        #######################################################################
-
-        return melspec
-
-    def get_text(self, text):
-        text = self.tp.encode_text(text)
-        space = [self.tp.encode_text("A A")[1]]
-
-        if self.prepend_space_to_text:
-            text = space + text
-
-        if self.append_space_to_text:
-            text = text + space
-
-        return torch.LongTensor(text)
-
-    def get_prior(self, index, mel_len, text_len):
-
-        if self.use_betabinomial_interpolator:
-            return torch.from_numpy(self.betabinomial_interpolator(mel_len,
-                                                                   text_len))
-
-        if self.betabinomial_tmp_dir is not None:
-            audiopath, *_ = self.audiopaths_and_text[index]
-            fname = Path(audiopath).relative_to(self.dataset_path)
-            fname = fname.with_suffix('.pt')
-            cached_fpath = Path(self.betabinomial_tmp_dir, fname)
-
-            if cached_fpath.is_file():
-                return torch.load(cached_fpath)
-
-        attn_prior = beta_binomial_prior_distribution(text_len, mel_len)
-
-        if self.betabinomial_tmp_dir is not None:
-            cached_fpath.parent.mkdir(parents=True, exist_ok=True)
-            torch.save(attn_prior, cached_fpath)
-
-        return attn_prior
-
-    def get_pitch(self, index, mel_len=None):
-        audiopath, *fields = self.audiopaths_and_text[index]
-
-        if self.n_speakers > 1:
-            spk = int(fields[-1])
-        else:
-            spk = 0
-
-        if self.load_pitch_from_disk:
-            pitchpath = fields[0]
-            pitch = torch.load(pitchpath)
-            if self.pitch_mean is not None:
-                assert self.pitch_std is not None
-                pitch = normalize_pitch(pitch, self.pitch_mean, self.pitch_std)
-            return pitch
-
-        if self.pitch_tmp_dir is not None:
-            fname = Path(audiopath).relative_to(self.dataset_path)
-            fname_method = fname.with_suffix('.pt')
-            cached_fpath = Path(self.pitch_tmp_dir, fname_method)
-            if cached_fpath.is_file():
-                return torch.load(cached_fpath)
-
-        # No luck so far - calculate
-        wav = audiopath
-        if not wav.endswith('.wav'):
-            wav = re.sub('/mels/', '/wavs/', wav)
-            wav = re.sub('.pt$', '.wav', wav)
-
-        pitch_mel = estimate_pitch(wav, mel_len, self.f0_method,
-                                   self.pitch_mean, self.pitch_std)
-
-        if self.pitch_tmp_dir is not None and not cached_fpath.is_file():
-            cached_fpath.parent.mkdir(parents=True, exist_ok=True)
-            torch.save(pitch_mel, cached_fpath)
-
-        return pitch_mel
-
-
-class TTSCollate:
-    """Zero-pads model inputs and targets based on number of frames per step"""
-
-    def __call__(self, batch):
-        """Collate training batch from normalized text and mel-spec"""
-        # Right zero-pad all one-hot text sequences to max input length
-        input_lengths, ids_sorted_decreasing = torch.sort(
-            torch.LongTensor([len(x[0]) for x in batch]),
-            dim=0, descending=True)
-        max_input_len = input_lengths[0]
-
-        text_padded = torch.LongTensor(len(batch), max_input_len)
-        text_padded.zero_()
-        for i in range(len(ids_sorted_decreasing)):
-            text = batch[ids_sorted_decreasing[i]][0]
-            text_padded[i, :text.size(0)] = text
-
-        # Right zero-pad mel-spec
-        num_mels = batch[0][1].size(0)
-        max_target_len = max([x[1].size(1) for x in batch])
-
-        # Include mel padded and gate padded
-        mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
-        mel_padded.zero_()
-        output_lengths = torch.LongTensor(len(batch))
-        for i in range(len(ids_sorted_decreasing)):
-            mel = batch[ids_sorted_decreasing[i]][1]
-            mel_padded[i, :, :mel.size(1)] = mel
-            output_lengths[i] = mel.size(1)
-
-        n_formants = batch[0][3].shape[0]
-        pitch_padded = torch.zeros(mel_padded.size(0), n_formants,
-                                   mel_padded.size(2), dtype=batch[0][3].dtype)
-        energy_padded = torch.zeros_like(pitch_padded[:, 0, :])
-
-        for i in range(len(ids_sorted_decreasing)):
-            pitch = batch[ids_sorted_decreasing[i]][3]
-            energy = batch[ids_sorted_decreasing[i]][4]
-            pitch_padded[i, :, :pitch.shape[1]] = pitch
-            energy_padded[i, :energy.shape[0]] = energy
-
-        if batch[0][5] is not None:
-            speaker = torch.zeros_like(input_lengths)
-            for i in range(len(ids_sorted_decreasing)):
-                speaker[i] = batch[ids_sorted_decreasing[i]][5]
-        else:
-            speaker = None
-
-        attn_prior_padded = torch.zeros(len(batch), max_target_len,
-                                        max_input_len)
-        attn_prior_padded.zero_()
-        for i in range(len(ids_sorted_decreasing)):
-            prior = batch[ids_sorted_decreasing[i]][6]
-            attn_prior_padded[i, :prior.size(0), :prior.size(1)] = prior
-
-        # Count number of items - characters in text
-        len_x = [x[2] for x in batch]
-        len_x = torch.Tensor(len_x)
-
-        audiopaths = [batch[i][7] for i in ids_sorted_decreasing]
-
-        return (text_padded, input_lengths, mel_padded, output_lengths, len_x,
-                pitch_padded, energy_padded, speaker, attn_prior_padded,
-                audiopaths)
-
-
-def batch_to_gpu(batch):
-    (text_padded, input_lengths, mel_padded, output_lengths, len_x,
-     pitch_padded, energy_padded, speaker, attn_prior, audiopaths) = batch
-
-    text_padded = to_gpu(text_padded).long()
-    input_lengths = to_gpu(input_lengths).long()
-    mel_padded = to_gpu(mel_padded).float()
-    output_lengths = to_gpu(output_lengths).long()
-    pitch_padded = to_gpu(pitch_padded).float()
-    energy_padded = to_gpu(energy_padded).float()
-    attn_prior = to_gpu(attn_prior).float()
-    if speaker is not None:
-        speaker = to_gpu(speaker).long()
-
-    # Alignments act as both inputs and targets - pass shallow copies
-    x = [text_padded, input_lengths, mel_padded, output_lengths,
-         pitch_padded, energy_padded, speaker, attn_prior, audiopaths]
-    y = [mel_padded, input_lengths, output_lengths]
-    len_x = torch.sum(output_lengths)
-    return (x, y, len_x)
|
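The collate and GPU-transfer helpers above compose exactly as in the training script; a minimal usage sketch (the `dataset` variable stands in for an already-constructed `TTSDataset`):

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=8, collate_fn=TTSCollate(),
                    shuffle=True, drop_last=True)
for batch in loader:
    x, y, num_frames = batch_to_gpu(batch)  # x feeds the model, y the loss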
fastpitch/data_function_model_py.zip
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7977160d12775529ba3426181093c8bca7927e52024f6d4faa91e1a0e53ef008
-size 9564
fastpitch/utils_trainplot_transformers.zip
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:5d3b72d713798a0552c0939fded67cb89655e12cc406a6fe793fcc7c9f63456a
-size 16365
fastpitch/utils_trainplot_transformers/train_1_with_plot.py
DELETED
@@ -1,591 +0,0 @@
-# *****************************************************************************
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-#     * Redistributions of source code must retain the above copyright
-#       notice, this list of conditions and the following disclaimer.
-#     * Redistributions in binary form must reproduce the above copyright
-#       notice, this list of conditions and the following disclaimer in the
-#       documentation and/or other materials provided with the distribution.
-#     * Neither the name of the NVIDIA CORPORATION nor the
-#       names of its contributors may be used to endorse or promote products
-#       derived from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
-# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#
-# *****************************************************************************
-
-import argparse
-import copy
-import os
-import time
-from collections import defaultdict, OrderedDict
-from itertools import cycle
-
-import numpy as np
-import torch
-import torch.distributed as dist
-import amp_C
-from apex.optimizers import FusedAdam, FusedLAMB
-from torch.nn.parallel import DistributedDataParallel
-from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
-
-import common.tb_dllogger as logger
-import models
-from common.tb_dllogger import log
-from common.repeated_dataloader import (RepeatedDataLoader,
-                                        RepeatedDistributedSampler)
-from common.text import cmudict
-from common.utils import BenchmarkStats, Checkpointer, prepare_tmp
-from fastpitch.attn_loss_function import AttentionBinarizationLoss
-from fastpitch.data_function import batch_to_gpu, TTSCollate, TTSDataset
-from fastpitch.loss_function import FastPitchLoss
-
-import matplotlib.pyplot as plt
-
-
-def parse_args(parser):
-    parser.add_argument('-o', '--output', type=str, required=True,
-                        help='Directory to save checkpoints')
-    parser.add_argument('-d', '--dataset-path', type=str, default='./',
-                        help='Path to dataset')
-    parser.add_argument('--log-file', type=str, default=None,
-                        help='Path to a DLLogger log file')
-
-    train = parser.add_argument_group('training setup')
-    train.add_argument('--epochs', type=int, required=True,
-                       help='Number of total epochs to run')
-    train.add_argument('--epochs-per-checkpoint', type=int, default=50,
-                       help='Number of epochs per checkpoint')
-    train.add_argument('--checkpoint-path', type=str, default=None,
-                       help='Checkpoint path to resume training')
-    train.add_argument('--keep-milestones', default=list(range(100, 1000, 100)),
-                       type=int, nargs='+',
-                       help='Milestone checkpoints to keep from removing')
-    train.add_argument('--resume', action='store_true',
-                       help='Resume training from the last checkpoint')
-    train.add_argument('--seed', type=int, default=1234,
-                       help='Seed for PyTorch random number generators')
-    train.add_argument('--amp', action='store_true',
-                       help='Enable AMP')
-    train.add_argument('--cuda', action='store_true',
-                       help='Run on GPU using CUDA')
-    train.add_argument('--cudnn-benchmark', action='store_true',
-                       help='Enable cudnn benchmark mode')
-    train.add_argument('--ema-decay', type=float, default=0,
-                       help='Discounting factor for training weights EMA')
-    train.add_argument('--grad-accumulation', type=int, default=1,
-                       help='Training steps to accumulate gradients for')
-    train.add_argument('--kl-loss-start-epoch', type=int, default=250,
-                       help='Start adding the hard attention loss term')
-    train.add_argument('--kl-loss-warmup-epochs', type=int, default=100,
-                       help='Gradually increase the hard attention loss term')
-    train.add_argument('--kl-loss-weight', type=float, default=1.0,
-                       help='Weight of the hard attention loss term')
-    train.add_argument('--benchmark-epochs-num', type=int, default=20,
-                       help='Number of epochs for calculating final stats')
-    train.add_argument('--validation-freq', type=int, default=1,
-                       help='Validate every N epochs to use less compute')
-
-    opt = parser.add_argument_group('optimization setup')
-    opt.add_argument('--optimizer', type=str, default='lamb',
-                     help='Optimization algorithm')
-    opt.add_argument('-lr', '--learning-rate', type=float, required=True,
-                     help='Learning rate')
-    opt.add_argument('--weight-decay', default=1e-6, type=float,
-                     help='Weight decay')
-    opt.add_argument('--grad-clip-thresh', default=1000.0, type=float,
-                     help='Clip threshold for gradients')
-    opt.add_argument('-bs', '--batch-size', type=int, required=True,
-                     help='Batch size per GPU')
-    opt.add_argument('--warmup-steps', type=int, default=1000,
-                     help='Number of steps for lr warmup')
-    opt.add_argument('--dur-predictor-loss-scale', type=float,
-                     default=1.0, help='Rescale duration predictor loss')
-    opt.add_argument('--pitch-predictor-loss-scale', type=float,
-                     default=1.0, help='Rescale pitch predictor loss')
-    opt.add_argument('--attn-loss-scale', type=float,
-                     default=1.0, help='Rescale alignment loss')
-
-    data = parser.add_argument_group('dataset parameters')
-    data.add_argument('--training-files', type=str, nargs='*', required=True,
-                      help='Paths to training filelists.')
-    data.add_argument('--validation-files', type=str, nargs='*',
-                      required=True, help='Paths to validation filelists')
-    data.add_argument('--text-cleaners', nargs='*',
-                      default=['english_cleaners'], type=str,
-                      help='Type of text cleaners for input text')
-    data.add_argument('--symbol-set', type=str, default='english_basic',
-                      help='Define symbol set for input text')
-    data.add_argument('--p-arpabet', type=float, default=0.0,
-                      help='Probability of using arpabets instead of graphemes '
-                           'for each word; set 0 for pure grapheme training')
-    data.add_argument('--heteronyms-path', type=str, default='cmudict/heteronyms',
-                      help='Path to the list of heteronyms')
-    data.add_argument('--cmudict-path', type=str, default='cmudict/cmudict-0.7b',
-                      help='Path to the pronouncing dictionary')
-    data.add_argument('--prepend-space-to-text', action='store_true',
-                      help='Capture leading silence with a space token')
-    data.add_argument('--append-space-to-text', action='store_true',
-                      help='Capture trailing silence with a space token')
-    data.add_argument('--num-workers', type=int, default=2,  # 6
-                      help='Subprocesses for train and val DataLoaders')
-    data.add_argument('--trainloader-repeats', type=int, default=100,
-                      help='Repeats the dataset to prolong epochs')
-
-    cond = parser.add_argument_group('data for conditioning')
-    cond.add_argument('--n-speakers', type=int, default=1,
-                      help='Number of speakers in the dataset. '
-                           'n_speakers > 1 enables speaker embeddings')
-    cond.add_argument('--load-pitch-from-disk', action='store_true',
-                      help='Use pitch cached on disk with prepare_dataset.py')
-    cond.add_argument('--pitch-online-method', default='pyin',
-                      choices=['pyin'],
-                      help='Calculate pitch on the fly during training')
-    cond.add_argument('--pitch-online-dir', type=str, default=None,
-                      help='A directory for storing pitch calculated on-line')
-    cond.add_argument('--pitch-mean', type=float, default=125.626816,  # default=214.72203,
-                      help='Normalization value for pitch')
-    cond.add_argument('--pitch-std', type=float, default=37.52,  # default=65.72038,
-                      help='Normalization value for pitch')
-    cond.add_argument('--load-mel-from-disk', action='store_true',
-                      help='Use mel-spectrograms cache on the disk')  # XXX
-
-    audio = parser.add_argument_group('audio parameters')
-    audio.add_argument('--max-wav-value', default=32768.0, type=float,
-                       help='Maximum audiowave value')
-    audio.add_argument('--sampling-rate', default=22050, type=int,
-                       help='Sampling rate')
-    audio.add_argument('--filter-length', default=1024, type=int,
-                       help='Filter length')
-    audio.add_argument('--hop-length', default=256, type=int,
-                       help='Hop (stride) length')
-    audio.add_argument('--win-length', default=1024, type=int,
-                       help='Window length')
-    audio.add_argument('--mel-fmin', default=0.0, type=float,
-                       help='Minimum mel frequency')
-    audio.add_argument('--mel-fmax', default=8000.0, type=float,
-                       help='Maximum mel frequency')
-
-    dist = parser.add_argument_group('distributed setup')
-    dist.add_argument('--local_rank', type=int, default=os.getenv('LOCAL_RANK', 0),
-                      help='Rank of the process for multiproc; do not set manually')
-    dist.add_argument('--world_size', type=int, default=os.getenv('WORLD_SIZE', 1),
-                      help='Number of processes for multiproc; do not set manually')
-    return parser
-
-
-def reduce_tensor(tensor, num_gpus):
-    rt = tensor.clone()
-    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
-    return rt.true_divide(num_gpus)
-
-
-def init_distributed(args, world_size, rank):
-    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
-    print("Initializing distributed training")
-
-    # Set cuda device so everything is done on the right GPU.
-    torch.cuda.set_device(rank % torch.cuda.device_count())
-
-    # Initialize distributed communication
-    dist.init_process_group(backend=('nccl' if args.cuda else 'gloo'),
-                            init_method='env://')
-    print("Done initializing distributed training")
-
-
-def validate(model, epoch, total_iter, criterion, val_loader, distributed_run,
-             batch_to_gpu, local_rank, ema=False):
-    was_training = model.training
-    model.eval()
-
-    tik = time.perf_counter()
-    with torch.no_grad():
-        val_meta = defaultdict(float)
-        val_num_frames = 0
-        for i, batch in enumerate(val_loader):
-            x, y, num_frames = batch_to_gpu(batch)
-            y_pred = model(x)
-            loss, meta = criterion(y_pred, y, is_training=False, meta_agg='sum')
-
-            if distributed_run:
-                for k, v in meta.items():
-                    val_meta[k] += reduce_tensor(v, 1)
-                val_num_frames += reduce_tensor(num_frames.data, 1).item()
-            else:
-                for k, v in meta.items():
-                    val_meta[k] += v
-                val_num_frames += num_frames.item()
-
-            # NOTE: ugly patch to visualize the first utterance of the validation corpus.
-            # The goal is to determine if the training is progressing properly
-            if (i == 0) and (local_rank == 0) and (not ema):
-                # Plot some debug information
-                fig, axs = plt.subplots(2, 2, figsize=(21, 14))
-
-                # - Mel-spectrogram
-                pred_mel = y_pred[0][0, :, :].cpu().detach().numpy().astype(np.float32).T
-                orig_mel = y[0][0, :, :].cpu().detach().numpy().astype(np.float32)
-                axs[0, 0].imshow(orig_mel, aspect='auto', origin='lower', interpolation='nearest')
-                axs[1, 0].imshow(pred_mel, aspect='auto', origin='lower', interpolation='nearest')
-
-                # Prosody
-                f0_pred = y_pred[4][0, :].cpu().detach().numpy().astype(np.float32)
-                f0_ori = y_pred[5][0, :].cpu().detach().numpy().astype(np.float32)
-                axs[1, 1].plot(f0_ori)
-                axs[1, 1].plot(f0_pred)
-
-                # # Duration
-                # att_pred = y_pred[2][0, :].cpu().detach().numpy().astype(np.float32)
-                # att_ori = x[7][0, :].cpu().detach().numpy().astype(np.float32)
-                # axs[0, 1].imshow(att_ori, aspect='auto', origin='lower', interpolation='nearest')
-
-                if not os.path.exists("debug_epoch/"):
-                    os.makedirs("debug_epoch/")
-
-                fig.savefig(f'debug_epoch/{epoch:06d}.png', bbox_inches='tight')
-
-        val_meta = {k: v / len(val_loader.dataset) for k, v in val_meta.items()}
-
-    val_meta['took'] = time.perf_counter() - tik
-
-    log((epoch,) if epoch is not None else (), tb_total_steps=total_iter,
-        subset='val_ema' if ema else 'val',
-        data=OrderedDict([
-            ('loss', val_meta['loss'].item()),
-            ('mel_loss', val_meta['mel_loss'].item()),
-            ('frames/s', val_num_frames / val_meta['took']),
-            ('took', val_meta['took'])]),
-        )
-
-    if was_training:
-        model.train()
-    return val_meta
-
-
-def adjust_learning_rate(total_iter, opt, learning_rate, warmup_iters=None):
-    if warmup_iters == 0:
-        scale = 1.0
-    elif total_iter > warmup_iters:
-        scale = 1. / (total_iter ** 0.5)
-    else:
-        scale = total_iter / (warmup_iters ** 1.5)
-
-    for param_group in opt.param_groups:
-        param_group['lr'] = learning_rate * scale
-
-
-def apply_ema_decay(model, ema_model, decay):
-    if not decay:
-        return
-    st = model.state_dict()
-    add_module = hasattr(model, 'module') and not hasattr(ema_model, 'module')
-    for k, v in ema_model.state_dict().items():
-        if add_module and not k.startswith('module.'):
-            k = 'module.' + k
-        v.copy_(decay * v + (1 - decay) * st[k])
-
-
-def init_multi_tensor_ema(model, ema_model):
-    model_weights = list(model.state_dict().values())
-    ema_model_weights = list(ema_model.state_dict().values())
-    ema_overflow_buf = torch.cuda.IntTensor([0])
-    return model_weights, ema_model_weights, ema_overflow_buf
-
-
-def apply_multi_tensor_ema(decay, model_weights, ema_weights, overflow_buf):
-    amp_C.multi_tensor_axpby(
-        65536, overflow_buf, [ema_weights, model_weights, ema_weights],
-        decay, 1 - decay, -1)
-
-
-def main():
-    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
-                                     allow_abbrev=False)
-    parser = parse_args(parser)
-    args, _ = parser.parse_known_args()
-
-    if args.p_arpabet > 0.0:
-        cmudict.initialize(args.cmudict_path, args.heteronyms_path)
-
-    distributed_run = args.world_size > 1
-
-    torch.manual_seed(args.seed + args.local_rank)
-    np.random.seed(args.seed + args.local_rank)
-
-    if args.local_rank == 0:
-        if not os.path.exists(args.output):
-            os.makedirs(args.output)
-
-    log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
-    tb_subsets = ['train', 'val']
-    if args.ema_decay > 0.0:
-        tb_subsets.append('val_ema')
-
-    logger.init(log_fpath, args.output, enabled=(args.local_rank == 0),
-                tb_subsets=tb_subsets)
-    logger.parameters(vars(args), tb_subset='train')
-
-    parser = models.parse_model_args('FastPitch', parser)
-    args, unk_args = parser.parse_known_args()
-    if len(unk_args) > 0:
-        raise ValueError(f'Invalid options {unk_args}')
-
-    torch.backends.cudnn.benchmark = args.cudnn_benchmark
-
-    if distributed_run:
-        init_distributed(args, args.world_size, args.local_rank)
-    else:
-        if args.trainloader_repeats > 1:
-            print('WARNING: Disabled --trainloader-repeats, supported only for'
-                  ' multi-GPU data loading.')
-            args.trainloader_repeats = 1
-
-    device = torch.device('cuda' if args.cuda else 'cpu')
-    model_config = models.get_model_config('FastPitch', args)
-    model = models.get_model('FastPitch', model_config, device)
-
-    attention_kl_loss = AttentionBinarizationLoss()
-
-    # Store pitch mean/std as params to translate from Hz during inference
-    model.pitch_mean[0] = args.pitch_mean
-    model.pitch_std[0] = args.pitch_std
-
-    kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
-              weight_decay=args.weight_decay)
-    if args.optimizer == 'adam':
-        optimizer = FusedAdam(model.parameters(), **kw)
-        # optimizer = torch.optim.Adam(model.parameters(), **kw)
-    elif args.optimizer == 'lamb':
-        optimizer = FusedLAMB(model.parameters(), **kw)
-        # optimizer = torch.optim.Adam(model.parameters(), **kw)
-    else:
-        raise ValueError
-
-    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
-
-    if args.ema_decay > 0:
-        ema_model = copy.deepcopy(model)
-    else:
-        ema_model = None
-
-    if distributed_run:
-        model = DistributedDataParallel(
-            model, device_ids=[args.local_rank], output_device=args.local_rank,
-            find_unused_parameters=True)
-
-    train_state = {'epoch': 1, 'total_iter': 1}
-    checkpointer = Checkpointer(args.output, args.keep_milestones)
-
-    checkpointer.maybe_load(model, optimizer, scaler, train_state, args,
-                            ema_model)
-
-    start_epoch = train_state['epoch']
-    total_iter = train_state['total_iter']
-
-    criterion = FastPitchLoss(
-        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
-        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale,
-        attn_loss_scale=args.attn_loss_scale)
-
-    collate_fn = TTSCollate()
-
-    if args.local_rank == 0:
-        prepare_tmp(args.pitch_online_dir)
-
-    trainset = TTSDataset(audiopaths_and_text=args.training_files, **vars(args))
-    valset = TTSDataset(audiopaths_and_text=args.validation_files, **vars(args))
-
-    if distributed_run:
-        train_sampler = RepeatedDistributedSampler(args.trainloader_repeats,
-                                                   trainset, drop_last=True)
-        val_sampler = DistributedSampler(valset)
-        shuffle = False
-    else:
-        train_sampler, val_sampler, shuffle = None, None, False  # was True
-
-    # 4 workers are optimal on DGX-1 (from epoch 2 onwards)
-    kw = {'num_workers': args.num_workers, 'batch_size': args.batch_size,
-          'collate_fn': collate_fn}
-    train_loader = RepeatedDataLoader(args.trainloader_repeats, trainset,
-                                      shuffle=shuffle, drop_last=True,
-                                      sampler=train_sampler, pin_memory=True,
-                                      persistent_workers=True, **kw)
-    val_loader = DataLoader(valset, shuffle=False, sampler=val_sampler,
-                            pin_memory=False, **kw)
-    if args.ema_decay:
-        mt_ema_params = init_multi_tensor_ema(model, ema_model)
-
-    model.train()
-    bmark_stats = BenchmarkStats()
-
-    torch.cuda.synchronize()
-    for epoch in range(start_epoch, args.epochs + 1):
-        epoch_start_time = time.perf_counter()
-
-        epoch_loss = 0.0
-        epoch_mel_loss = 0.0
-        epoch_num_frames = 0
-        epoch_frames_per_sec = 0.0
-
-        if distributed_run:
-            train_loader.sampler.set_epoch(epoch)
-
-        iter_loss = 0
-        iter_num_frames = 0
-        iter_meta = {}
-        iter_start_time = time.perf_counter()
-
-        epoch_iter = 1
-        for batch, accum_step in zip(train_loader,
-                                     cycle(range(1, args.grad_accumulation + 1))):
-            if accum_step == 1:
-                adjust_learning_rate(total_iter, optimizer, args.learning_rate,
-                                     args.warmup_steps)
-
-                model.zero_grad(set_to_none=True)
-
-            x, y, num_frames = batch_to_gpu(batch)
-
-            with torch.cuda.amp.autocast(enabled=args.amp):
-                y_pred = model(x)
-                loss, meta = criterion(y_pred, y)
-
-                if (args.kl_loss_start_epoch is not None
-                        and epoch >= args.kl_loss_start_epoch):
-
-                    if args.kl_loss_start_epoch == epoch and epoch_iter == 1:
-                        print('Begin hard_attn loss')
-
-                    _, _, _, _, _, _, _, _, attn_soft, attn_hard, _, _ = y_pred
-                    binarization_loss = attention_kl_loss(attn_hard, attn_soft)
-                    kl_weight = min((epoch - args.kl_loss_start_epoch) / args.kl_loss_warmup_epochs, 1.0) * args.kl_loss_weight
-                    meta['kl_loss'] = binarization_loss.clone().detach() * kl_weight
-                    loss += kl_weight * binarization_loss
-
-                else:
-                    meta['kl_loss'] = torch.zeros_like(loss)
-                    kl_weight = 0
-                    binarization_loss = 0
-
-                loss /= args.grad_accumulation
-
-            meta = {k: v / args.grad_accumulation
-                    for k, v in meta.items()}
-
-            if args.amp:
-                scaler.scale(loss).backward()
-            else:
-                loss.backward()
-
-            if distributed_run:
-                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
-                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
-                meta = {k: reduce_tensor(v, args.world_size) for k, v in meta.items()}
-            else:
-                reduced_loss = loss.item()
-                reduced_num_frames = num_frames.item()
-            if np.isnan(reduced_loss):
-                raise Exception("loss is NaN")
-
-            iter_loss += reduced_loss
-            iter_num_frames += reduced_num_frames
-            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}
-
-            if accum_step % args.grad_accumulation == 0:
-
-                logger.log_grads_tb(total_iter, model)
-                if args.amp:
-                    scaler.unscale_(optimizer)
-                    torch.nn.utils.clip_grad_norm_(
-                        model.parameters(), args.grad_clip_thresh)
-                    scaler.step(optimizer)
-                    scaler.update()
-                else:
-                    torch.nn.utils.clip_grad_norm_(
-                        model.parameters(), args.grad_clip_thresh)
-                    optimizer.step()
-
-                if args.ema_decay > 0.0:
-                    apply_multi_tensor_ema(args.ema_decay, *mt_ema_params)
-
-                iter_mel_loss = iter_meta['mel_loss'].item()
-                iter_kl_loss = iter_meta['kl_loss'].item()
-                iter_time = time.perf_counter() - iter_start_time
-                epoch_frames_per_sec += iter_num_frames / iter_time
-                epoch_loss += iter_loss
-                epoch_num_frames += iter_num_frames
-                epoch_mel_loss += iter_mel_loss
-
-                num_iters = len(train_loader) // args.grad_accumulation
-                log((epoch, epoch_iter, num_iters), tb_total_steps=total_iter,
-                    subset='train', data=OrderedDict([
-                        ('loss', iter_loss),
-                        ('mel_loss', iter_mel_loss),
-                        ('kl_loss', iter_kl_loss),
-                        ('kl_weight', kl_weight),
-                        ('frames/s', iter_num_frames / iter_time),
-                        ('took', iter_time),
-                        ('lrate', optimizer.param_groups[0]['lr'])]),
-                    )
-
-                iter_loss = 0
-                iter_num_frames = 0
-                iter_meta = {}
-                iter_start_time = time.perf_counter()
-
-                if epoch_iter == num_iters:
-                    break
-                epoch_iter += 1
-                total_iter += 1
-
-        # Finished epoch
-        epoch_loss /= epoch_iter
-        epoch_mel_loss /= epoch_iter
-        epoch_time = time.perf_counter() - epoch_start_time
-
-        log((epoch,), tb_total_steps=None, subset='train_avg',
-            data=OrderedDict([
-                ('loss', epoch_loss),
-                ('mel_loss', epoch_mel_loss),
-                ('frames/s', epoch_num_frames / epoch_time),
-                ('took', epoch_time)]),
-            )
-        bmark_stats.update(epoch_num_frames, epoch_loss, epoch_mel_loss,
-                           epoch_time)
-
-        if epoch % args.validation_freq == 0:
-            validate(model, epoch, total_iter, criterion, val_loader,
-                     distributed_run, batch_to_gpu, ema=False, local_rank=args.local_rank)
-
-        if args.ema_decay > 0:
-            validate(ema_model, epoch, total_iter, criterion, val_loader,
-                     distributed_run, batch_to_gpu, args.local_rank, ema=True)
-
-        # save before making sched.step() for proper loading of LR
-        checkpointer.maybe_save(args, model, ema_model, optimizer, scaler,
-                                epoch, total_iter, model_config)
-        logger.flush()
-
-    # Finished training
-    if len(bmark_stats) > 0:
-        log((), tb_total_steps=None, subset='train_avg',
-            data=bmark_stats.get(args.benchmark_epochs_num))
-
-    validate(model, None, total_iter, criterion, val_loader, distributed_run,
-             batch_to_gpu, local_rank=args.local_rank)
-
-
-if __name__ == '__main__':
-    main()
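A note on `adjust_learning_rate` above: it implements the inverse-square-root ("Noam") schedule, scaling the base learning rate by total_iter / warmup_iters**1.5 during warmup and by total_iter**-0.5 afterwards, so the two branches meet at warmup_iters**-0.5. A quick self-contained check of that behaviour (`lr_scale` is a name introduced here for illustration):

def lr_scale(total_iter, warmup_iters=1000):
    # Mirrors the scale computed in adjust_learning_rate.
    if warmup_iters == 0:
        return 1.0
    if total_iter > warmup_iters:
        return total_iter ** -0.5
    return total_iter / warmup_iters ** 1.5

for step in (1, 500, 1000, 4000):
    print(step, lr_scale(step))
# 1000 / 1000**1.5 == 1000**-0.5, so the schedule is continuous at its peak.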
fastpitch/utils_trainplot_transformers/transformer.py
DELETED
@@ -1,213 +0,0 @@
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from common.utils import mask_from_lens
-
-
-class PositionalEmbedding(nn.Module):
-    def __init__(self, demb):
-        super(PositionalEmbedding, self).__init__()
-        self.demb = demb
-        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
-        self.register_buffer('inv_freq', inv_freq)
-
-    def forward(self, pos_seq, bsz=None):
-        sinusoid_inp = torch.matmul(torch.unsqueeze(pos_seq, -1),
-                                    torch.unsqueeze(self.inv_freq, 0))
-        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
-        if bsz is not None:
-            return pos_emb[None, :, :].expand(bsz, -1, -1)
-        else:
-            return pos_emb[None, :, :]
-
-
-class PositionwiseConvFF(nn.Module):
-    def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False):
-        super(PositionwiseConvFF, self).__init__()
-
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-
-        self.CoreNet = nn.Sequential(
-            nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)),
-            nn.ReLU(),
-            # nn.Dropout(dropout),  # worse convergence
-            nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)),
-            nn.Dropout(dropout),
-        )
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-
-    def forward(self, inp):
-        return self._forward(inp)
-
-    def _forward(self, inp):
-        if self.pre_lnorm:
-            # layer normalization + positionwise feed-forward
-            core_out = inp.transpose(1, 2)
-            core_out = self.CoreNet(self.layer_norm(core_out).to(inp.dtype))
-            core_out = core_out.transpose(1, 2)
-
-            # residual connection
-            output = core_out + inp
-        else:
-            # positionwise feed-forward
-            core_out = inp.transpose(1, 2)
-            core_out = self.CoreNet(core_out)
-            core_out = core_out.transpose(1, 2)
-
-            # residual connection + layer normalization
-            output = self.layer_norm(inp + core_out).to(inp.dtype)
-
-        return output
-
-
-class MultiHeadAttn(nn.Module):
-    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1,
-                 pre_lnorm=False):
-        super(MultiHeadAttn, self).__init__()
-
-        self.n_head = n_head
-        self.d_model = d_model
-        self.d_head = d_head
-        self.scale = 1 / (d_head ** 0.5)
-        self.pre_lnorm = pre_lnorm
-
-        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head)
-        self.drop = nn.Dropout(dropout)
-        self.dropatt = nn.Dropout(dropatt)
-        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
-        self.layer_norm = nn.LayerNorm(d_model)
-
-    def forward(self, inp, attn_mask=None):
-        return self._forward(inp, attn_mask)
-
-    def _forward(self, inp, attn_mask=None):
-        residual = inp
-
-        if self.pre_lnorm:
-            # layer normalization
-            inp = self.layer_norm(inp)
-
-        n_head, d_head = self.n_head, self.d_head
-
-        head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2)
-        head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head)
-        head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head)
-        head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head)
-
-        q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
-        k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
-        v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
-
-        attn_score = torch.bmm(q, k.transpose(1, 2))
-        attn_score.mul_(self.scale)
-
-        if attn_mask is not None:
-            attn_mask = attn_mask.unsqueeze(1).to(attn_score.dtype)
-            attn_mask = attn_mask.repeat(n_head, attn_mask.size(2), 1)
-            attn_score.masked_fill_(attn_mask.to(torch.bool), -float('inf'))
-
-        attn_prob = F.softmax(attn_score, dim=2)
-        attn_prob = self.dropatt(attn_prob)
-        attn_vec = torch.bmm(attn_prob, v)
-
-        attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head)
-        attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(
-            inp.size(0), inp.size(1), n_head * d_head)
-
-        # linear projection
-        attn_out = self.o_net(attn_vec)
-        attn_out = self.drop(attn_out)
-
-        if self.pre_lnorm:
-            # residual connection
-            output = residual + attn_out
-        else:
-            # residual connection + layer normalization
-            output = self.layer_norm(residual + attn_out)
-
-        output = output.to(attn_out.dtype)
-
-        return output
-
-
-class TransformerLayer(nn.Module):
-    def __init__(self, n_head, d_model, d_head, d_inner, kernel_size, dropout,
-                 **kwargs):
-        super(TransformerLayer, self).__init__()
-
-        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = PositionwiseConvFF(d_model, d_inner, kernel_size, dropout,
-                                         pre_lnorm=kwargs.get('pre_lnorm'))
-
-    def forward(self, dec_inp, mask=None):
-        output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2))
-        output *= mask
-        output = self.pos_ff(output)
-        output *= mask
-        return output
-
-
-class FFTransformer(nn.Module):
-    def __init__(self, n_layer, n_head, d_model, d_head, d_inner, kernel_size,
-                 dropout, dropatt, dropemb=0.0, embed_input=True,
-                 n_embed=None, d_embed=None, padding_idx=0, pre_lnorm=False):
-        super(FFTransformer, self).__init__()
-        self.d_model = d_model
-        self.n_head = n_head
-        self.d_head = d_head
-        self.padding_idx = padding_idx
-
-        if embed_input:
-            self.word_emb = nn.Embedding(n_embed, d_embed or d_model,
-                                         padding_idx=self.padding_idx)
-        else:
-            self.word_emb = None
-
-        self.pos_emb = PositionalEmbedding(self.d_model)
-        self.drop = nn.Dropout(dropemb)
-        self.layers = nn.ModuleList()
-
-        for _ in range(n_layer):
-            self.layers.append(
-                TransformerLayer(
-                    n_head, d_model, d_head, d_inner, kernel_size, dropout,
-                    dropatt=dropatt, pre_lnorm=pre_lnorm)
-            )
-
-    def forward(self, dec_inp, seq_lens=None, conditioning=0):
-        if self.word_emb is None:
-            inp = dec_inp
-            mask = mask_from_lens(seq_lens).unsqueeze(2)
-        else:
-            inp = self.word_emb(dec_inp)
-            # [bsz x L x 1]
-            mask = (dec_inp != self.padding_idx).unsqueeze(2)
-
-        pos_seq = torch.arange(inp.size(1), device=inp.device).to(inp.dtype)
-        pos_emb = self.pos_emb(pos_seq) * mask
-
-        out = self.drop(inp + pos_emb + conditioning)
-
-        for layer in self.layers:
-            out = layer(out, mask=mask)
-
-        # out = self.drop(out)
-        return out, mask
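For orientation, `PositionalEmbedding` above produces the usual sinusoidal encoding of shape [1, T, d_model], or [bsz, T, d_model] when `bsz` is given. A minimal shape check against the class as defined (assuming an even `demb`; the values are illustrative):

import torch

pe = PositionalEmbedding(demb=384)
pos_seq = torch.arange(10, dtype=torch.float)
print(pe(pos_seq).shape)         # torch.Size([1, 10, 384])
print(pe(pos_seq, bsz=4).shape)  # torch.Size([4, 10, 384])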
fastpitch/utils_trainplot_transformers/transformer_jit.py
DELETED
|
@@ -1,255 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
from typing import List, Optional
|
| 16 |
-
|
| 17 |
-
import torch
|
| 18 |
-
import torch.nn as nn
|
| 19 |
-
import torch.nn.functional as F
|
| 20 |
-
|
| 21 |
-
from common.utils import mask_from_lens
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class PositionalEmbedding(nn.Module):
|
| 25 |
-
def __init__(self, demb):
|
| 26 |
-
super(PositionalEmbedding, self).__init__()
|
| 27 |
-
self.demb = demb
|
| 28 |
-
inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
|
| 29 |
-
self.register_buffer('inv_freq', inv_freq)
|
| 30 |
-
|
| 31 |
-
def forward(self, pos_seq, bsz: Optional[int] = None):
|
| 32 |
-
sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
|
| 33 |
-
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
|
| 34 |
-
if bsz is not None:
|
| 35 |
-
return pos_emb[None, :, :].expand(bsz, -1, -1)
|
| 36 |
-
else:
|
| 37 |
-
return pos_emb[None, :, :]
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
class PositionwiseFF(nn.Module):
|
| 41 |
-
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
|
| 42 |
-
super(PositionwiseFF, self).__init__()
|
| 43 |
-
|
| 44 |
-
self.d_model = d_model
|
| 45 |
-
self.d_inner = d_inner
|
| 46 |
-
self.dropout = dropout
|
| 47 |
-
|
| 48 |
-
self.CoreNet = nn.Sequential(
|
| 49 |
-
nn.Linear(d_model, d_inner), nn.ReLU(),
|
| 50 |
-
nn.Dropout(dropout),
|
| 51 |
-
nn.Linear(d_inner, d_model),
|
| 52 |
-
nn.Dropout(dropout),
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
self.layer_norm = nn.LayerNorm(d_model)
|
| 56 |
-
self.pre_lnorm = pre_lnorm
|
| 57 |
-
|
| 58 |
-
def forward(self, inp):
|
| 59 |
-
if self.pre_lnorm:
|
| 60 |
-
# layer normalization + positionwise feed-forward
|
| 61 |
-
core_out = self.CoreNet(self.layer_norm(inp))
|
| 62 |
-
|
| 63 |
-
# residual connection
|
| 64 |
-
output = core_out + inp
|
| 65 |
-
else:
|
| 66 |
-
# positionwise feed-forward
|
| 67 |
-
core_out = self.CoreNet(inp)
|
| 68 |
-
|
| 69 |
-
# residual connection + layer normalization
|
| 70 |
-
output = self.layer_norm(inp + core_out)
|
| 71 |
-
|
| 72 |
-
return output
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
class PositionwiseConvFF(nn.Module):
|
| 76 |
-
def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False):
|
| 77 |
-
super(PositionwiseConvFF, self).__init__()
|
| 78 |
-
|
| 79 |
-
self.d_model = d_model
|
| 80 |
-
self.d_inner = d_inner
|
| 81 |
-
self.dropout = dropout
|
| 82 |
-
|
| 83 |
-
self.CoreNet = nn.Sequential(
|
| 84 |
-
nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)),
|
| 85 |
-
nn.ReLU(),
|
| 86 |
-
# nn.Dropout(dropout), # worse convergence
|
| 87 |
-
nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)),
|
| 88 |
-
nn.Dropout(dropout),
|
| 89 |
-
)
|
| 90 |
-
self.layer_norm = nn.LayerNorm(d_model)
|
| 91 |
-
self.pre_lnorm = pre_lnorm
|
| 92 |
-
|
| 93 |
-
def forward(self, inp):
|
| 94 |
-
if self.pre_lnorm:
|
| 95 |
-
# layer normalization + positionwise feed-forward
|
| 96 |
-
core_out = inp.transpose(1, 2)
|
| 97 |
-
core_out = self.CoreNet(self.layer_norm(core_out))
|
| 98 |
-
core_out = core_out.transpose(1, 2)
|
| 99 |
-
|
| 100 |
-
# residual connection
|
| 101 |
-
output = core_out + inp
|
| 102 |
-
else:
|
| 103 |
-
# positionwise feed-forward
|
| 104 |
-
core_out = inp.transpose(1, 2)
|
| 105 |
-
core_out = self.CoreNet(core_out)
|
| 106 |
-
core_out = core_out.transpose(1, 2)
|
| 107 |
-
|
| 108 |
-
# residual connection + layer normalization
|
| 109 |
-
output = self.layer_norm(inp + core_out)
|
| 110 |
-
|
| 111 |
-
return output
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
class MultiHeadAttn(nn.Module):
|
| 115 |
-
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1,
|
| 116 |
-
pre_lnorm=False):
|
| 117 |
-
super(MultiHeadAttn, self).__init__()
|
| 118 |
-
|
| 119 |
-
self.n_head = n_head
|
| 120 |
-
self.d_model = d_model
|
| 121 |
-
self.d_head = d_head
|
| 122 |
-
self.scale = 1 / (d_head ** 0.5)
|
| 123 |
-
self.dropout = dropout
|
| 124 |
-
self.pre_lnorm = pre_lnorm
|
| 125 |
-
|
| 126 |
-
self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head)
|
| 127 |
-
self.drop = nn.Dropout(dropout)
|
| 128 |
-
self.dropatt = nn.Dropout(dropatt)
|
| 129 |
-
self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
|
| 130 |
-
self.layer_norm = nn.LayerNorm(d_model)
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
def forward(self, inp, attn_mask: Optional[torch.Tensor] = None):
|
| 134 |
-
residual = inp
|
| 135 |
-
|
| 136 |
-
if self.pre_lnorm:
|
| 137 |
-
# layer normalization
|
| 138 |
-
inp = self.layer_norm(inp)
|
| 139 |
-
|
| 140 |
-
n_head, d_head = self.n_head, self.d_head
|
| 141 |
-
|
| 142 |
-
head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=-1)
|
| 143 |
-
head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head)
|
| 144 |
-
head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head)
|
| 145 |
-
head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head)
|
| 146 |
-
|
| 147 |
-
q = head_q.permute(0, 2, 1, 3).reshape(-1, inp.size(1), d_head)
|
| 148 |
-
k = head_k.permute(0, 2, 1, 3).reshape(-1, inp.size(1), d_head)
|
| 149 |
-
        v = head_v.permute(0, 2, 1, 3).reshape(-1, inp.size(1), d_head)

        attn_score = torch.bmm(q, k.transpose(1, 2))
        attn_score.mul_(self.scale)

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1)
            attn_mask = attn_mask.repeat(n_head, attn_mask.size(2), 1)
            attn_score.masked_fill_(attn_mask, -float('inf'))

        attn_prob = F.softmax(attn_score, dim=2)
        attn_prob = self.dropatt(attn_prob)
        attn_vec = torch.bmm(attn_prob, v)

        attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head)
        attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(
            inp.size(0), inp.size(1), n_head * d_head)

        # linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            # residual connection
            output = residual + attn_out
        else:
            # residual connection + layer normalization

            # XXX Running TorchScript on 20.02 and 20.03 containers crashes here
            # XXX Works well with 20.01-py3 container.
            # XXX dirty fix is:
            # XXX output = self.layer_norm(residual + attn_out).half()
            output = self.layer_norm(residual + attn_out)

        return output


class TransformerLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, kernel_size, dropout,
                 **kwargs):
        super(TransformerLayer, self).__init__()

        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseConvFF(d_model, d_inner, kernel_size, dropout,
                                         pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, mask):
        output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2))
        output *= mask
        output = self.pos_ff(output)
        output *= mask
        return output


class FFTransformer(nn.Module):
    def __init__(self, n_layer, n_head, d_model, d_head, d_inner, kernel_size,
                 dropout, dropatt, dropemb=0.0, embed_input=True,
                 n_embed=None, d_embed=None, padding_idx=0, pre_lnorm=False):
        super(FFTransformer, self).__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head
        self.padding_idx = padding_idx
        self.n_embed = n_embed

        self.embed_input = embed_input
        if embed_input:
            print(padding_idx)
            self.word_emb = nn.Embedding(n_embed, d_embed or d_model,
                                         padding_idx=self.padding_idx)
        else:
            self.word_emb = nn.Identity()

        self.pos_emb = PositionalEmbedding(self.d_model)
        self.drop = nn.Dropout(dropemb)
        self.layers = nn.ModuleList()

        for _ in range(n_layer):
            self.layers.append(
                TransformerLayer(
                    n_head, d_model, d_head, d_inner, kernel_size, dropout,
                    dropatt=dropatt, pre_lnorm=pre_lnorm)
            )

    def forward(self, dec_inp, seq_lens: Optional[torch.Tensor] = None,
                conditioning: Optional[torch.Tensor] = None):
        if not self.embed_input:
            inp = dec_inp
            assert seq_lens is not None
            mask = mask_from_lens(seq_lens).unsqueeze(2)
        else:
            inp = self.word_emb(dec_inp)
            # [bsz x L x 1]
            mask = (dec_inp != self.padding_idx).unsqueeze(2)

        pos_seq = torch.arange(inp.size(1), device=inp.device, dtype=inp.dtype)
        pos_emb = self.pos_emb(pos_seq) * mask
        if conditioning is not None:
            out = self.drop(inp + pos_emb + conditioning)
        else:
            out = self.drop(inp + pos_emb)

        for layer in self.layers:
            out = layer(out, mask=mask)

        # out = self.drop(out)
        return out, mask
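
Note: FFTransformer above is the FastPitch feed-forward Transformer stack. A minimal usage sketch follows; the hyperparameter values are illustrative assumptions in the spirit of the FastPitch defaults, not values read from this repository's config, and the sketch assumes the module definitions in this file (PositionalEmbedding, MultiHeadAttn, PositionwiseConvFF) are available.

import torch

# Illustrative instantiation; sizes are assumptions, not repo config values.
model = FFTransformer(n_layer=6, n_head=1, d_model=384, d_head=64,
                      d_inner=1536, kernel_size=3, dropout=0.1, dropatt=0.1,
                      embed_input=True, n_embed=148, padding_idx=0)

# Two padded symbol sequences; index 0 is the padding symbol, so the
# mask is derived inside forward() as (dec_inp != padding_idx).
dec_inp = torch.tensor([[5, 12, 7, 0, 0],
                        [3, 9, 14, 2, 8]])
out, mask = model(dec_inp)
print(out.shape)   # torch.Size([2, 5, 384])
print(mask.shape)  # torch.Size([2, 5, 1])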
fastpitch/utils_trainplot_transformers/utils.py
DELETED
@@ -1,291 +0,0 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# MIT License
#
# Copyright (c) 2020 Jungil Kong
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# The following functions/classes were based on code from https://github.com/jik876/hifi-gan:
# init_weights, get_padding, AttrDict

import ctypes
import glob
import os
import re
import shutil
import warnings
from collections import defaultdict, OrderedDict
from pathlib import Path
from typing import Optional

import librosa
import numpy as np

import torch
import torch.distributed as dist
from scipy.io.wavfile import read


def mask_from_lens(lens, max_len: Optional[int] = None):
    if max_len is None:
        max_len = lens.max()
    ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
    mask = torch.lt(ids, lens.unsqueeze(1))
    return mask


def load_wav(full_path, torch_tensor=False):
    import soundfile  # flac
    data, sampling_rate = soundfile.read(full_path, dtype='int16')
    if torch_tensor:
        return torch.FloatTensor(data.astype(np.float32)), sampling_rate
    else:
        return data, sampling_rate


def load_wav_to_torch(full_path, force_sampling_rate=None):
    if force_sampling_rate is not None:
        data, sampling_rate = librosa.load(full_path, sr=force_sampling_rate)
    else:
        sampling_rate, data = read(full_path)

    return torch.FloatTensor(data.astype(np.float32)), sampling_rate


def load_filepaths_and_text(dataset_path, fnames, has_speakers=False, split="|"):
    def split_line(root, line):
        parts = line.strip().split(split)
        if has_speakers:
            paths, non_paths = parts[:-2], parts[-2:]
        else:
            paths, non_paths = parts[:-1], parts[-1:]
        return tuple(str(Path(root, p)) for p in paths) + tuple(non_paths)

    fpaths_and_text = []
    for fname in fnames:
        with open(fname, encoding='utf-8') as f:
            fpaths_and_text += [split_line(dataset_path, line) for line in f]
    return fpaths_and_text


def to_gpu(x):
    x = x.contiguous()
    return x.cuda(non_blocking=True) if torch.cuda.is_available() else x


def l2_promote():
    _libcudart = ctypes.CDLL('libcudart.so')
    # Set device limit on the current device
    # cudaLimitMaxL2FetchGranularity = 0x05
    pValue = ctypes.cast((ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int))
    _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
    _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
    assert pValue.contents.value == 128


def prepare_tmp(path):
    if path is None:
        return
    p = Path(path)
    if p.is_dir():
        warnings.warn(f'{p} exists. Removing...')
        shutil.rmtree(p, ignore_errors=True)
    p.mkdir(parents=False, exist_ok=False)


def print_once(*msg):
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(*msg)


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


class DefaultAttrDict(defaultdict):
    def __init__(self, *args, **kwargs):
        super(DefaultAttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

    def __getattr__(self, item):
        return self[item]


class BenchmarkStats:
    """ Tracks statistics used for benchmarking. """
    def __init__(self):
        self.num_frames = []
        self.losses = []
        self.mel_losses = []
        self.took = []

    def update(self, num_frames, losses, mel_losses, took):
        self.num_frames.append(num_frames)
        self.losses.append(losses)
        self.mel_losses.append(mel_losses)
        self.took.append(took)

    def get(self, n_epochs):
        frames_s = sum(self.num_frames[-n_epochs:]) / sum(self.took[-n_epochs:])
        return {'frames/s': frames_s,
                'loss': np.mean(self.losses[-n_epochs:]),
                'mel_loss': np.mean(self.mel_losses[-n_epochs:]),
                'took': np.mean(self.took[-n_epochs:]),
                'benchmark_epochs_num': n_epochs}

    def __len__(self):
        return len(self.losses)


class Checkpointer:

    def __init__(self, save_dir, keep_milestones=[]):
        self.save_dir = save_dir
        self.keep_milestones = keep_milestones

        find = lambda name: [
            (int(re.search("_(\d+).pt", fn).group(1)), fn)
            for fn in glob.glob(f"{save_dir}/{name}_checkpoint_*.pt")]

        tracked = sorted(find("FastPitch"), key=lambda t: t[0])
        self.tracked = OrderedDict(tracked)

    def last_checkpoint(self, output):

        def corrupted(fpath):
            try:
                torch.load(fpath, map_location="cpu")
                return False
            except:
                warnings.warn(f"Cannot load {fpath}")
                return True

        saved = sorted(
            glob.glob(f"{output}/FastPitch_checkpoint_*.pt"),
            key=lambda f: int(re.search("_(\d+).pt", f).group(1)))

        if len(saved) >= 1 and not corrupted(saved[-1]):
            return saved[-1]
        elif len(saved) >= 2:
            return saved[-2]
        else:
            return None

    def maybe_load(self, model, optimizer, scaler, train_state, args,
                   ema_model=None):

        assert args.checkpoint_path is None or args.resume is False, (
            "Specify a single checkpoint source")

        fpath = None
        if args.checkpoint_path is not None:
            fpath = args.checkpoint_path
            self.tracked = OrderedDict()  # Do not track/delete prev ckpts
        elif args.resume:
            fpath = self.last_checkpoint(args.output)

        if fpath is None:
            return

        print_once(f"Loading model and optimizer state from {fpath}")
        ckpt = torch.load(fpath, map_location="cpu")
        train_state["epoch"] = ckpt["epoch"] + 1
        train_state["total_iter"] = ckpt["iteration"]

        no_pref = lambda sd: {re.sub("^module.", "", k): v for k, v in sd.items()}
        unwrap = lambda m: getattr(m, "module", m)

        unwrap(model).load_state_dict(no_pref(ckpt["state_dict"]))

        if ema_model is not None:
            unwrap(ema_model).load_state_dict(no_pref(ckpt["ema_state_dict"]))

        optimizer.load_state_dict(ckpt["optimizer"])

        if "scaler" in ckpt:
            scaler.load_state_dict(ckpt["scaler"])
        else:
            warnings.warn("AMP scaler state missing from the checkpoint.")

    def maybe_save(self, args, model, ema_model, optimizer, scaler, epoch,
                   total_iter, config):

        intermediate = (args.epochs_per_checkpoint > 0
                        and epoch % args.epochs_per_checkpoint == 0)
        final = epoch == args.epochs

        if not intermediate and not final and epoch not in self.keep_milestones:
            return

        rank = 0
        if dist.is_initialized():
            dist.barrier()
            rank = dist.get_rank()

        if rank != 0:
            return

        unwrap = lambda m: getattr(m, "module", m)
        ckpt = {"epoch": epoch,
                "iteration": total_iter,
                "config": config,
                "train_setup": args.__dict__,
                "state_dict": unwrap(model).state_dict(),
                "optimizer": optimizer.state_dict(),
                "scaler": scaler.state_dict()}
        if ema_model is not None:
            ckpt["ema_state_dict"] = unwrap(ema_model).state_dict()

        fpath = Path(args.output, f"FastPitch_checkpoint_{epoch}.pt")
        print(f"Saving model and optimizer state at epoch {epoch} to {fpath}")
        torch.save(ckpt, fpath)

        # Remove old checkpoints; keep milestones and the last two
        self.tracked[epoch] = fpath
        for epoch in set(list(self.tracked)[:-2]) - set(self.keep_milestones):
            try:
                os.remove(self.tracked[epoch])
            except:
                pass
            del self.tracked[epoch]
gradio_gui.py
DELETED
@@ -1,74 +0,0 @@
import gradio as gr
import syn_hifigan as syn
#import syn_k_univnet_multi as syn
import os, tempfile

languages = {"South Sámi": 0,
             "North Sámi": 1,
             "Lule Sámi": 2}

speakers = {"aj0": 0,
            "aj1": 1,
            "am": 2,
            "bi": 3,
            "kd": 4,
            "ln": 5,
            "lo": 6,
            "ms": 7,
            "mu": 8,
            "sa": 9}
public = False

tempdir = tempfile.gettempdir()

tts = syn.Synthesizer()


def speak(text, language, speaker, l_weight, s_weight, pace, postfilter):  # pitch_shift, pitch_std):
    # text frontend not implemented...
    text = text.replace("...", "…")
    print(speakers[speaker])
    audio = tts.speak(text, output_file=f'{tempdir}/tmp', lang=languages[language],
                      spkr=speakers[speaker], l_weight=l_weight, s_weight=s_weight,
                      pace=pace, clarity=postfilter)

    if not public:
        try:
            os.system("play " + tempdir + "/tmp.wav &")
        except:
            pass

    return (22050, audio)


controls = []
controls.append(gr.Textbox(label="text", value="Suohtas duinna deaivvadit."))
controls.append(gr.Dropdown(list(languages.keys()), label="language", value="North Sámi"))
controls.append(gr.Dropdown(list(speakers.keys()), label="speaker", value="ms"))

controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1, label="language weight"))
controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1, label="speaker weight"))

controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1.0, label="speech rate"))
controls.append(gr.Slider(minimum=0., maximum=2, step=0.05, value=1.0, label="post-processing"))


tts_gui = gr.Interface(
    fn=speak,
    inputs=controls,
    outputs=gr.Audio(label="output"),
    live=False
)


if __name__ == "__main__":
    tts_gui.launch(share=public)
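
Note: speak() follows the Gradio convention of returning a (sample_rate, waveform) tuple for a gr.Audio output, so it can also be called directly. A sketch, assuming the syn_hifigan.Synthesizer and its model files are available as above; the argument values mirror the widget defaults:

# Direct call, bypassing the Gradio widgets (illustrative only).
sr, audio = speak(text="Suohtas duinna deaivvadit.",
                  language="North Sámi", speaker="ms",
                  l_weight=1.0, s_weight=1.0, pace=1.0, postfilter=1.0)
print(sr)  # 22050; `audio` is whatever tts.speak() returns, typically an ndarray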
gradio_gui_katri.py
DELETED
@@ -1,73 +0,0 @@
import gradio as gr
#import syn_hifigan as syn
import syn_k_univnet_multi as syn
import os, tempfile

languages = {"South Sámi": 0,
             "North Sámi": 1,
             "Lule Sámi": 2}

speakers = {"aj0": 0,
            "aj1": 1,
            "am": 2,
            "bi": 3,
            "kd": 4,
            "ln": 5,
            "lo": 6,
            "ms": 7,
            "mu": 8,
            "sa": 9}
public = True

tempdir = tempfile.gettempdir()

tts = syn.Synthesizer()


def speak(text, language, speaker, l_weight, s_weight, pace):  # pitch_shift, pitch_std):
    # text frontend not implemented...
    text = text.replace("...", "…")
    print(speakers[speaker])
    audio = tts.speak(text, output_file=f'{tempdir}/tmp', lang=languages[language],
                      spkr=speakers[speaker], l_weight=l_weight, s_weight=s_weight,
                      pace=pace)

    if not public:
        try:
            os.system("play " + tempdir + "/tmp.wav &")
        except:
            pass

    return (22050, audio)


controls = []
controls.append(gr.Textbox(label="text", value="Suohtas duinna deaivvadit."))
controls.append(gr.Dropdown(list(languages.keys()), label="language", value="North Sámi"))
controls.append(gr.Dropdown(list(speakers.keys()), label="speaker", value="ms"))

controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1, label="Language weight"))
controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1, label="Speaker weight"))
#controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1.0, label="Pitch variance"))
controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1.0, label="speech rate"))


tts_gui = gr.Interface(
    fn=speak,
    inputs=controls,
    outputs=gr.Audio(label="output"),
    live=False
)


if __name__ == "__main__":
    tts_gui.launch(share=public)
prepare_dataset.py
DELETED
@@ -1,180 +0,0 @@
# *****************************************************************************
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the
#       names of its contributors may be used to endorse or promote products
#       derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************

import argparse
import time
from pathlib import Path

import torch
import tqdm
import dllogger as DLLogger
from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
from torch.utils.data import DataLoader

from fastpitch.data_function import TTSCollate, TTSDataset


def parse_args(parser):
    """
    Parse commandline arguments.
    """
    parser.add_argument('-d', '--dataset-path', type=str,
                        default='./', help='Path to dataset')
    parser.add_argument('--wav-text-filelists', required=True, nargs='+',
                        type=str, help='Files with audio paths and text')
    parser.add_argument('--extract-mels', action='store_true',
                        help='Calculate spectrograms from .wav files')
    parser.add_argument('--extract-pitch', action='store_true',
                        help='Extract pitch')
    parser.add_argument('--save-alignment-priors', action='store_true',
                        help='Pre-calculate diagonal matrices of alignment of text to audio')
    parser.add_argument('--log-file', type=str, default='preproc_log.json',
                        help='Filename for logging')
    parser.add_argument('--n-speakers', type=int, default=1)
    parser.add_argument('--n-languages', type=int, default=1)
    # Mel extraction
    parser.add_argument('--max-wav-value', default=32768.0, type=float,
                        help='Maximum audiowave value')
    parser.add_argument('--sampling-rate', default=22050, type=int,
                        help='Sampling rate')
    parser.add_argument('--filter-length', default=1024, type=int,
                        help='Filter length')
    parser.add_argument('--hop-length', default=256, type=int,
                        help='Hop (stride) length')
    parser.add_argument('--win-length', default=1024, type=int,
                        help='Window length')
    parser.add_argument('--mel-fmin', default=0.0, type=float,
                        help='Minimum mel frequency')
    parser.add_argument('--mel-fmax', default=8000.0, type=float,
                        help='Maximum mel frequency')
    parser.add_argument('--n-mel-channels', type=int, default=80)
    # Pitch extraction
    parser.add_argument('--f0-method', default='pyin', type=str,
                        choices=['pyin'], help='F0 estimation method')
    parser.add_argument('--pitch-mean', default=214, type=float,
                        help='Mean F0 for pitch normalization')
    parser.add_argument('--pitch-std', default=65, type=float,
                        help='F0 standard deviation for pitch normalization')
    # Performance
    parser.add_argument('-b', '--batch-size', default=1, type=int)
    parser.add_argument('--n-workers', type=int, default=16)
    return parser


def main():
    parser = argparse.ArgumentParser(description='FastPitch Data Pre-processing')
    parser = parse_args(parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    DLLogger.init(backends=[
        JSONStreamBackend(Verbosity.DEFAULT, Path(args.dataset_path, args.log_file)),
        StdOutBackend(Verbosity.VERBOSE)])
    for k, v in vars(args).items():
        DLLogger.log(step="PARAMETER", data={k: v})
    DLLogger.flush()

    if args.extract_mels:
        Path(args.dataset_path, 'mels').mkdir(parents=False, exist_ok=True)

    if args.extract_pitch:
        Path(args.dataset_path, 'pitch').mkdir(parents=False, exist_ok=True)

    if args.save_alignment_priors:
        Path(args.dataset_path, 'alignment_priors').mkdir(parents=False, exist_ok=True)

    for filelist in args.wav_text_filelists:

        print(f'Processing {filelist}...')

        dataset = TTSDataset(
            args.dataset_path,
            filelist,
            text_cleaners=['basic_cleaners'],
            n_mel_channels=args.n_mel_channels,
            p_arpabet=0.0,
            n_speakers=args.n_speakers,
            n_languages=args.n_languages,
            load_mel_from_disk=False,
            load_pitch_from_disk=False,
            pitch_mean=args.pitch_mean,
            pitch_std=args.pitch_std,
            max_wav_value=args.max_wav_value,
            sampling_rate=args.sampling_rate,
            filter_length=args.filter_length,
            hop_length=args.hop_length,
            win_length=args.win_length,
            mel_fmin=args.mel_fmin,
            mel_fmax=args.mel_fmax,
            betabinomial_online_dir=None,
            pitch_online_dir=None,
            pitch_online_method=args.f0_method)

        data_loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=False,
            sampler=None,
            num_workers=args.n_workers,
            collate_fn=TTSCollate(),
            pin_memory=False,
            drop_last=False)

        all_filenames = set()
        for i, batch in enumerate(tqdm.tqdm(data_loader)):
            tik = time.time()

            _, input_lens, mels, mel_lens, _, pitch, _, _, _, attn_prior, fpaths = batch

            # Ensure filenames are unique
            for p in fpaths:
                fname = Path(p).name
                if fname in all_filenames:
                    raise ValueError(f'Filename is not unique: {fname}')
                all_filenames.add(fname)

            if args.extract_mels:
                for j, mel in enumerate(mels):
                    fname = Path(fpaths[j]).with_suffix('.pt').name
                    fpath = Path(args.dataset_path, 'mels', fname)
                    torch.save(mel[:, :mel_lens[j]], fpath)

            if args.extract_pitch:
                for j, p in enumerate(pitch):
                    fname = Path(fpaths[j]).with_suffix('.pt').name
                    fpath = Path(args.dataset_path, 'pitch', fname)
                    torch.save(p[:mel_lens[j]], fpath)

            if args.save_alignment_priors:
                for j, prior in enumerate(attn_prior):
                    fname = Path(fpaths[j]).with_suffix('.pt').name
                    fpath = Path(args.dataset_path, 'alignment_priors', fname)
                    torch.save(prior[:mel_lens[j], :input_lens[j]], fpath)


if __name__ == '__main__':
    main()
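
Note: prepare_dataset.py writes one tensor per utterance into mels/, pitch/ and alignment_priors/ under the dataset path, keyed by the wav file stem. A minimal sketch of loading the extracted features back ("utt_0001" is a hypothetical file name used for illustration):

import torch
from pathlib import Path

dataset_path = Path("ALL_SAMI")
mel = torch.load(dataset_path / "mels" / "utt_0001.pt")      # [n_mel_channels, T]
pitch = torch.load(dataset_path / "pitch" / "utt_0001.pt")   # pitch contour, trimmed to T frames
prior = torch.load(dataset_path / "alignment_priors" / "utt_0001.pt")  # [T, text_len]
print(mel.shape, pitch.shape, prior.shape)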
run_training_cluster_s.sh
DELETED
@@ -1,33 +0,0 @@
#!/bin/bash
#SBATCH --job-name=train_fastpitch
#SBATCH --account=nn9866k
#SBATCH --time=11:50:00
#SBATCH --mem=16G
#SBATCH --partition=accel
#SBATCH --gres=gpu:1

# == Logging

#SBATCH --error="log_err"    # Save the error messages
#SBATCH --output="log_out"   # Save the stdout

## Set up job environment:
# set -o errexit   # Exit the script on any error
# set -o nounset   # Treat any unset variables as an error

## Activate environment
# source ~/.bashrc

eval "$(conda shell.bash hook)"
conda activate fastpitch

# Setup monitoring
nvidia-smi --query-gpu=timestamp,utilization.gpu,utilization.memory \
    --format=csv --loop=1 > "gpu_util-$SLURM_JOB_ID.csv" &
NVIDIA_MONITOR_PID=$!  # Capture PID of monitoring process

# Run our computation
bash scripts/train_2.sh

# After computation stop monitoring
kill -SIGINT "$NVIDIA_MONITOR_PID"
scripts/docker/build.sh
DELETED
@@ -1,3 +0,0 @@
#!/usr/bin/env bash

docker build . -t fastpitch:latest
scripts/docker/interactive.sh
DELETED
@@ -1,5 +0,0 @@
#!/usr/bin/env bash

PORT=${PORT:-8888}

docker run --gpus=all -it --rm -e CUDA_VISIBLE_DEVICES --ipc=host -p $PORT:$PORT -v $PWD:/workspace/fastpitch/ fastpitch:latest bash
docker run --gpus=all -it --rm -e CUDA_VISIBLE_DEVICES --ipc=host -p $PORT:$PORT -v $PWD:/workspace/fastpitch/ fastpitch:latest bash
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/download_cmudict.sh
DELETED
@@ -1,10 +0,0 @@
#!/usr/bin/env bash

set -e

: ${CMUDICT_DIR:="cmudict"}

if [ ! -f $CMUDICT_DIR/cmudict-0.7b ]; then
    echo "Downloading cmudict-0.7b ..."
    wget https://github.com/Alexir/CMUdict/raw/master/cmudict-0.7b -qO $CMUDICT_DIR/cmudict-0.7b
fi
scripts/download_dataset.sh
DELETED
@@ -1,17 +0,0 @@
#!/usr/bin/env bash

set -e

scripts/download_cmudict.sh

DATA_DIR="LJSpeech-1.1"
LJS_ARCH="LJSpeech-1.1.tar.bz2"
LJS_URL="http://data.keithito.com/data/speech/${LJS_ARCH}"

if [ ! -d ${DATA_DIR} ]; then
    echo "Downloading ${LJS_ARCH} ..."
    wget -q ${LJS_URL}
    echo "Extracting ${LJS_ARCH} ..."
    tar jxvf ${LJS_ARCH}
    rm -f ${LJS_ARCH}
fi
scripts/download_models.sh
DELETED
@@ -1,63 +0,0 @@
#!/usr/bin/env bash

set -e

MODEL_NAMES="$@"
[ -z "$MODEL_NAMES" ] && { echo "Usage: $0 [fastpitch|waveglow|hifigan|hifigan-finetuned-fastpitch]"; exit 1; }

function download_ngc_model() {
    mkdir -p "$MODEL_DIR"

    if [ ! -f "${MODEL_DIR}/${MODEL_ZIP}" ]; then
        echo "Downloading ${MODEL_ZIP} ..."
        wget --content-disposition -O ${MODEL_DIR}/${MODEL_ZIP} ${MODEL_URL} \
            || { echo "ERROR: Failed to download ${MODEL_ZIP} from NGC"; exit 1; }
    fi

    if [ ! -f "${MODEL_DIR}/${MODEL}" ]; then
        echo "Extracting ${MODEL} ..."
        unzip -qo ${MODEL_DIR}/${MODEL_ZIP} -d ${MODEL_DIR} \
            || { echo "ERROR: Failed to extract ${MODEL_ZIP}"; exit 1; }

        echo "OK"

    else
        echo "${MODEL} already downloaded."
    fi

}

for MODEL_NAME in $MODEL_NAMES
do
    case $MODEL_NAME in
        "fastpitch")
            MODEL_DIR="pretrained_models/fastpitch"
            MODEL_ZIP="fastpitch_pyt_fp32_ckpt_v1_1_21.05.0.zip"
            MODEL="nvidia_fastpitch_210824.pt"
            MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/fastpitch_pyt_fp32_ckpt_v1_1/versions/21.05.0/zip"
            ;;
        "hifigan")
            MODEL_DIR="pretrained_models/hifigan"
            MODEL_ZIP="hifigan__pyt_ckpt_ds-ljs22khz_21.08.0_amp.zip"
            MODEL="hifigan_gen_checkpoint_6500.pt"
            MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/dle/hifigan__pyt_ckpt_ds-ljs22khz/versions/21.08.0_amp/zip"
            ;;
        "hifigan-finetuned-fastpitch")
            MODEL_DIR="pretrained_models/hifigan"
            MODEL_ZIP="hifigan__pyt_ckpt_mode-finetune_ds-ljs22khz_21.08.0_amp.zip"
            MODEL="hifigan_gen_checkpoint_10000_ft.pt"
            MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/dle/hifigan__pyt_ckpt_mode-finetune_ds-ljs22khz/versions/21.08.0_amp/zip"
            ;;
        "waveglow")
            MODEL_DIR="pretrained_models/waveglow"
            MODEL_ZIP="waveglow_ckpt_amp_256_20.01.0.zip"
            MODEL="nvidia_waveglow256pyt_fp16.pt"
            MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/waveglow_ckpt_amp_256/versions/20.01.0/zip"
            ;;
        *)
            echo "Unrecognized model: ${MODEL_NAME}"
            exit 2
            ;;
    esac
    download_ngc_model "$MODEL_DIR" "$MODEL_ZIP" "$MODEL" "$MODEL_URL"
done
scripts/inference_benchmark.sh
DELETED
@@ -1,16 +0,0 @@
#!/usr/bin/env bash

set -a

: ${FILELIST:="phrases/benchmark_8_128.tsv"}
: ${OUTPUT_DIR:="./output/audio_$(basename ${FILELIST} .tsv)"}
: ${TORCHSCRIPT:=true}
: ${BS_SEQUENCE:="1 4 8"}
: ${WARMUP:=64}
: ${REPEATS:=500}
: ${AMP:=false}

for BATCH_SIZE in $BS_SEQUENCE ; do
    LOG_FILE="$OUTPUT_DIR"/perf-infer_amp-${AMP}_bs${BATCH_SIZE}.json
    bash scripts/inference_example.sh "$@"
done
scripts/inference_example.sh
DELETED
@@ -1,78 +0,0 @@
#!/usr/bin/env bash

export CUDNN_V8_API_ENABLED=1  # Keep the flag for older containers
export TORCH_CUDNN_V8_API_ENABLED=1

: ${DATASET_DIR:="sander_splits"}
: ${BATCH_SIZE:=1}
: ${FILELIST:="phrases/giehttjit.txt"}
: ${AMP:=false}
: ${TORCHSCRIPT:=true}
: ${WARMUP:=0}
: ${REPEATS:=1}
: ${CPU:=false}
: ${PHONE:=true}

# Paths to pre-trained models downloadable from NVIDIA NGC (LJSpeech-1.1)
FASTPITCH_LJ="output/FastPitch_checkpoint_660.pt"
HIFIGAN_LJ="pretrained_models/hifigan/hifigan_gen_checkpoint_10000_ft.pt"
WAVEGLOW_LJ="pretrained_models/waveglow/nvidia_waveglow256pyt_fp16.pt"

# Mel-spectrogram generator (optional; can synthesize from ground-truth spectrograms)
: ${FASTPITCH=$FASTPITCH_LJ}

# Vocoder (set only one)
#: ${HIFIGAN=$HIFIGAN_LJ}
: ${WAVEGLOW=$WAVEGLOW_LJ}

[[ "$FASTPITCH" == "$FASTPITCH_LJ" && ! -f "$FASTPITCH" ]] && { echo "Downloading $FASTPITCH from NGC..."; bash scripts/download_models.sh fastpitch; }
[[ "$WAVEGLOW" == "$WAVEGLOW_LJ" && ! -f "$WAVEGLOW" ]] && { echo "Downloading $WAVEGLOW from NGC..."; bash scripts/download_models.sh waveglow; }
[[ "$HIFIGAN" == "$HIFIGAN_LJ" && ! -f "$HIFIGAN" ]] && { echo "Downloading $HIFIGAN from NGC..."; bash scripts/download_models.sh hifigan-finetuned-fastpitch; }

if [[ "$HIFIGAN" == "$HIFIGAN_LJ" && "$FASTPITCH" != "$FASTPITCH_LJ" ]]; then
    echo -e "\nNOTE: Using HiFi-GAN checkpoint trained for the LJSpeech-1.1 dataset."
    echo -e "NOTE: If you're using a different dataset, consider training a new HiFi-GAN model or switch to WaveGlow."
    echo -e "NOTE: See $0 for details.\n"
fi

# Synthesis
: ${SPEAKER:=0}
: ${DENOISING:=0.01}

if [ ! -n "$OUTPUT_DIR" ]; then
    OUTPUT_DIR="./output/audio_$(basename ${FILELIST} .tsv)"
    [ "$AMP" = true ]  && OUTPUT_DIR+="_fp16"
    [ "$AMP" = false ] && OUTPUT_DIR+="_fp32"
    [ -n "$FASTPITCH" ]   && OUTPUT_DIR+="_fastpitch"
    [ ! -n "$FASTPITCH" ] && OUTPUT_DIR+="_gt-mel"
    [ -n "$WAVEGLOW" ] && OUTPUT_DIR+="_waveglow"
    [ -n "$HIFIGAN" ]  && OUTPUT_DIR+="_hifigan"
    OUTPUT_DIR+="_denoise-"${DENOISING}
fi
: ${LOG_FILE:="$OUTPUT_DIR/nvlog_infer.json"}
mkdir -p "$OUTPUT_DIR"

echo -e "\nAMP=$AMP, batch_size=$BATCH_SIZE\n"

ARGS=""
# ARGS+=" --cudnn-benchmark"  # Enable for benchmarking or long operation
ARGS+=" --dataset-path $DATASET_DIR"
ARGS+=" -i $FILELIST"
ARGS+=" -o $OUTPUT_DIR"
ARGS+=" --log-file $LOG_FILE"
ARGS+=" --batch-size $BATCH_SIZE"
ARGS+=" --denoising-strength $DENOISING"
ARGS+=" --warmup-steps $WARMUP"
ARGS+=" --repeats $REPEATS"
ARGS+=" --speaker $SPEAKER"
[ "$CPU" = false ] && ARGS+=" --cuda"
[ "$CPU" = false ] && ARGS+=" --cudnn-benchmark"
[ "$AMP" = true ]  && ARGS+=" --amp"
[ "$TORCHSCRIPT" = true ] && ARGS+=" --torchscript"
[ -n "$HIFIGAN" ]   && ARGS+=" --hifigan $HIFIGAN"
[ -n "$WAVEGLOW" ]  && ARGS+=" --waveglow $WAVEGLOW"
[ -n "$FASTPITCH" ] && ARGS+=" --fastpitch $FASTPITCH"
[ "$PHONE" = true ] && ARGS+=" --p-arpabet 1.0"

python inference.py $ARGS "$@"
scripts/prepare_dataset.sh
DELETED
@@ -1,19 +0,0 @@
#!/usr/bin/env bash

set -e

: ${DATA_DIR:=ALL_SAMI}
: ${ARGS="--extract-mels"}

python prepare_dataset.py \
    --wav-text-filelists filelists/all_sami_filelist_shuf_200_train.txt \
    --n-workers 8 \
    --batch-size 1 \
    --dataset-path $DATA_DIR \
    --extract-pitch \
    --f0-method pyin \
    --pitch-mean 150 \
    --pitch-std 40 \
    --n-speakers 10 \
    --n-languages 3 \
    $ARGS
scripts/train.sh
DELETED
@@ -1,100 +0,0 @@
#!/usr/bin/env bash

export OMP_NUM_THREADS=1

: ${NUM_GPUS:=8}
: ${BATCH_SIZE:=16}
: ${GRAD_ACCUMULATION:=2}
: ${OUTPUT_DIR:="./output"}
: ${LOG_FILE:=$OUTPUT_DIR/nvlog.json}
: ${DATASET_PATH:=LJSpeech-1.1}
: ${TRAIN_FILELIST:=filelists/ljs_audio_pitch_text_train_v3.txt}
: ${VAL_FILELIST:=filelists/ljs_audio_pitch_text_val.txt}
: ${AMP:=false}
: ${SEED:=""}

: ${LEARNING_RATE:=0.1}

# Adjust these when the amount of data changes
: ${EPOCHS:=1000}
: ${EPOCHS_PER_CHECKPOINT:=20}
: ${WARMUP_STEPS:=1000}
: ${KL_LOSS_WARMUP:=100}

# Train a mixed phoneme/grapheme model
: ${PHONE:=true}
# Enable energy conditioning
: ${ENERGY:=true}
: ${TEXT_CLEANERS:=english_cleaners_v2}
# Add a dummy space prefix/suffix if audio is not precisely trimmed
: ${APPEND_SPACES:=false}

: ${LOAD_PITCH_FROM_DISK:=true}
: ${LOAD_MEL_FROM_DISK:=false}

# For multispeaker models, add speaker ID = {0, 1, ...} as the last filelist column
: ${NSPEAKERS:=1}
: ${SAMPLING_RATE:=22050}

# Adjust env variables to maintain the global batch size: NUM_GPUS x BATCH_SIZE x GRAD_ACCUMULATION = 256.
GBS=$(($NUM_GPUS * $BATCH_SIZE * $GRAD_ACCUMULATION))
[ $GBS -ne 256 ] && echo -e "\nWARNING: Global batch size changed from 256 to ${GBS}."
echo -e "\nAMP=$AMP, ${NUM_GPUS}x${BATCH_SIZE}x${GRAD_ACCUMULATION}" \
        "(global batch size ${GBS})\n"

ARGS=""
ARGS+=" --cuda"
ARGS+=" -o $OUTPUT_DIR"
ARGS+=" --log-file $LOG_FILE"
ARGS+=" --dataset-path $DATASET_PATH"
ARGS+=" --training-files $TRAIN_FILELIST"
ARGS+=" --validation-files $VAL_FILELIST"
ARGS+=" -bs $BATCH_SIZE"
ARGS+=" --grad-accumulation $GRAD_ACCUMULATION"
ARGS+=" --optimizer lamb"
ARGS+=" --epochs $EPOCHS"
ARGS+=" --epochs-per-checkpoint $EPOCHS_PER_CHECKPOINT"
ARGS+=" --resume"
ARGS+=" --warmup-steps $WARMUP_STEPS"
ARGS+=" -lr $LEARNING_RATE"
ARGS+=" --weight-decay 1e-6"
ARGS+=" --grad-clip-thresh 1000.0"
ARGS+=" --dur-predictor-loss-scale 0.1"
ARGS+=" --pitch-predictor-loss-scale 0.1"
ARGS+=" --trainloader-repeats 100"
ARGS+=" --validation-freq 10"

# Autoalign & new features
ARGS+=" --kl-loss-start-epoch 0"
ARGS+=" --kl-loss-warmup-epochs $KL_LOSS_WARMUP"
ARGS+=" --text-cleaners $TEXT_CLEANERS"
ARGS+=" --n-speakers $NSPEAKERS"

[ "$AMP" = "true" ]    && ARGS+=" --amp"
[ "$PHONE" = "true" ]  && ARGS+=" --p-arpabet 1.0"
[ "$ENERGY" = "true" ] && ARGS+=" --energy-conditioning"
[ "$SEED" != "" ]      && ARGS+=" --seed $SEED"
[ "$LOAD_MEL_FROM_DISK" = true ]   && ARGS+=" --load-mel-from-disk"
[ "$LOAD_PITCH_FROM_DISK" = true ] && ARGS+=" --load-pitch-from-disk"
[ "$PITCH_ONLINE_DIR" != "" ]    && ARGS+=" --pitch-online-dir $PITCH_ONLINE_DIR"  # e.g., /dev/shm/pitch
[ "$PITCH_ONLINE_METHOD" != "" ] && ARGS+=" --pitch-online-method $PITCH_ONLINE_METHOD"
[ "$APPEND_SPACES" = true ] && ARGS+=" --prepend-space-to-text"
[ "$APPEND_SPACES" = true ] && ARGS+=" --append-space-to-text"

if [ "$SAMPLING_RATE" == "44100" ]; then
    ARGS+=" --sampling-rate 44100"
    ARGS+=" --filter-length 2048"
    ARGS+=" --hop-length 512"
    ARGS+=" --win-length 2048"
    ARGS+=" --mel-fmin 0.0"
    ARGS+=" --mel-fmax 22050.0"

elif [ "$SAMPLING_RATE" != "22050" ]; then
    echo "Unknown sampling rate $SAMPLING_RATE"
    exit 1
fi

mkdir -p "$OUTPUT_DIR"

: ${DISTRIBUTED:="-m torch.distributed.launch --nproc_per_node $NUM_GPUS"}
python $DISTRIBUTED train.py $ARGS "$@"
scripts/train_multilang.sh
DELETED
@@ -1,110 +0,0 @@
#!/usr/bin/env bash

export OMP_NUM_THREADS=1

: ${NUM_GPUS:=1}
: ${BATCH_SIZE:=1}
: ${GRAD_ACCUMULATION:=32}
: ${OUTPUT_DIR:="./output_multilang"}
: ${LOG_FILE:=$OUTPUT_DIR/nvlog.json}
: ${DATASET_PATH:=ALL_SAMI}
#: ${DATASET_PATH:=mikal_urheim}
#: ${TRAIN_FILELIST:=filelists/smj_sander_text_noshorts_shuff_pitch.txt}
#: ${TRAIN_FILELIST:=filelists/mikal_urheim_pitch_shuf.txt}
: ${TRAIN_FILELIST:=filelists/all_sami_filelist_shuf_200_train.txt}
#: ${VAL_FILELIST:=filelists/smj_sander_text_noshorts_shuff_val_pitch.txt}
#: ${VAL_FILELIST:=filelists/mikal_urheim_pitch_shuf_val.txt}
: ${VAL_FILELIST:=filelists/all_sami_filelist_shuf_200_val.txt}
: ${AMP:=false}
: ${SEED:=""}

: ${LEARNING_RATE:=0.1}

# Adjust these when the amount of data changes
: ${EPOCHS:=1000}
: ${EPOCHS_PER_CHECKPOINT:=10}
: ${WARMUP_STEPS:=1000}
: ${KL_LOSS_WARMUP:=100}

# Train a mixed phoneme/grapheme model
: ${PHONE:=false}
# Enable energy conditioning
: ${ENERGY:=true}
: ${TEXT_CLEANERS:=basic_cleaners}
: ${SYMBOL_SET:=all_sami}
# Add a dummy space prefix/suffix if audio is not precisely trimmed
: ${APPEND_SPACES:=false}

: ${LOAD_PITCH_FROM_DISK:=true}  # was true
: ${LOAD_MEL_FROM_DISK:=false}

# For multispeaker models, add speaker ID = {0, 1, ...} as the last filelist column
: ${NSPEAKERS:=10}   # 10
: ${NLANGUAGES:=3}   # 3
: ${SAMPLING_RATE:=22050}

# Adjust env variables to maintain the global batch size: NUM_GPUS x BATCH_SIZE x GRAD_ACCUMULATION = 256.
GBS=$(($NUM_GPUS * $BATCH_SIZE * $GRAD_ACCUMULATION))
[ $GBS -ne 256 ] && echo -e "\nWARNING: Global batch size changed from 256 to ${GBS}."
echo -e "\nAMP=$AMP, ${NUM_GPUS}x${BATCH_SIZE}x${GRAD_ACCUMULATION}" \
        "(global batch size ${GBS})\n"

ARGS=""
ARGS+=" --cuda"
ARGS+=" -o $OUTPUT_DIR"
ARGS+=" --log-file $LOG_FILE"
ARGS+=" --dataset-path $DATASET_PATH"
ARGS+=" --training-files $TRAIN_FILELIST"
ARGS+=" --validation-files $VAL_FILELIST"
ARGS+=" -bs $BATCH_SIZE"
ARGS+=" --grad-accumulation $GRAD_ACCUMULATION"
ARGS+=" --optimizer lamb"  #adam
ARGS+=" --epochs $EPOCHS"
ARGS+=" --epochs-per-checkpoint $EPOCHS_PER_CHECKPOINT"
ARGS+=" --resume"
ARGS+=" --warmup-steps $WARMUP_STEPS"
ARGS+=" -lr $LEARNING_RATE"
ARGS+=" --weight-decay 1e-6"
ARGS+=" --grad-clip-thresh 1000.0"
ARGS+=" --dur-predictor-loss-scale 0.1"
ARGS+=" --pitch-predictor-loss-scale 0.1"
ARGS+=" --trainloader-repeats 100"
ARGS+=" --validation-freq 1"  #10

# Autoalign & new features
ARGS+=" --kl-loss-start-epoch 0"
ARGS+=" --kl-loss-warmup-epochs $KL_LOSS_WARMUP"
ARGS+=" --text-cleaners $TEXT_CLEANERS"
ARGS+=" --n-speakers $NSPEAKERS"
ARGS+=" --n-languages $NLANGUAGES"
ARGS+=" --symbol-set $SYMBOL_SET"

[ "$AMP" = "true" ]    && ARGS+=" --amp"
[ "$PHONE" = "true" ]  && ARGS+=" --p-arpabet 1.0"
[ "$ENERGY" = "true" ] && ARGS+=" --energy-conditioning"
[ "$SEED" != "" ]      && ARGS+=" --seed $SEED"
[ "$LOAD_MEL_FROM_DISK" = true ]   && ARGS+=" --load-mel-from-disk"
[ "$LOAD_PITCH_FROM_DISK" = true ] && ARGS+=" --load-pitch-from-disk"
[ "$PITCH_ONLINE_DIR" != "" ]    && ARGS+=" --pitch-online-dir $PITCH_ONLINE_DIR"  # e.g., /dev/shm/pitch
[ "$PITCH_ONLINE_METHOD" != "" ] && ARGS+=" --pitch-online-method $PITCH_ONLINE_METHOD"
[ "$APPEND_SPACES" = true ] && ARGS+=" --prepend-space-to-text"
[ "$APPEND_SPACES" = true ] && ARGS+=" --append-space-to-text"

if [ "$SAMPLING_RATE" == "44100" ]; then
    ARGS+=" --sampling-rate 44100"
    ARGS+=" --filter-length 2048"
    ARGS+=" --hop-length 512"
    ARGS+=" --win-length 2048"
    ARGS+=" --mel-fmin 0.0"
    ARGS+=" --mel-fmax 22050.0"

elif [ "$SAMPLING_RATE" != "22050" ]; then
    echo "Unknown sampling rate $SAMPLING_RATE"
    exit 1
fi

mkdir -p "$OUTPUT_DIR"

: ${DISTRIBUTED:="-m torch.distributed.launch --nproc_per_node $NUM_GPUS"}
#python $DISTRIBUTED train.py $ARGS "$@"
python train_1_with_plot_multilang.py $ARGS "$@"
train_1_with_plot_multilang.py
DELETED
|
@@ -1,593 +0,0 @@
-# *****************************************************************************
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-#     * Redistributions of source code must retain the above copyright
-#       notice, this list of conditions and the following disclaimer.
-#     * Redistributions in binary form must reproduce the above copyright
-#       notice, this list of conditions and the following disclaimer in the
-#       documentation and/or other materials provided with the distribution.
-#     * Neither the name of the NVIDIA CORPORATION nor the
-#       names of its contributors may be used to endorse or promote products
-#       derived from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
-# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#
-# *****************************************************************************
-
-import argparse
-import copy
-import os
-import time
-from collections import defaultdict, OrderedDict
-from itertools import cycle
-
-import numpy as np
-import torch
-import torch.distributed as dist
-import amp_C
-from apex.optimizers import FusedAdam, FusedLAMB
-from torch.nn.parallel import DistributedDataParallel
-from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
-
-import common.tb_dllogger as logger
-import models
-from common.tb_dllogger import log
-from common.repeated_dataloader import (RepeatedDataLoader,
-                                        RepeatedDistributedSampler)
-from common.text import cmudict
-from common.utils import BenchmarkStats, Checkpointer, prepare_tmp
-from fastpitch.attn_loss_function import AttentionBinarizationLoss
-from fastpitch.data_function import batch_to_gpu, TTSCollate, TTSDataset
-from fastpitch.loss_function import FastPitchLoss
-
-import matplotlib.pyplot as plt
-
-def parse_args(parser):
-    parser.add_argument('-o', '--output', type=str, required=True,
-                        help='Directory to save checkpoints')
-    parser.add_argument('-d', '--dataset-path', type=str, default='./',
-                        help='Path to dataset')
-    parser.add_argument('--log-file', type=str, default=None,
-                        help='Path to a DLLogger log file')
-
-    train = parser.add_argument_group('training setup')
-    train.add_argument('--epochs', type=int, required=True,
-                       help='Number of total epochs to run')
-    train.add_argument('--epochs-per-checkpoint', type=int, default=50,
-                       help='Number of epochs per checkpoint')
-    train.add_argument('--checkpoint-path', type=str, default=None,
-                       help='Checkpoint path to resume training')
-    train.add_argument('--keep-milestones', default=list(range(100, 1000, 100)),
-                       type=int, nargs='+',
-                       help='Milestone checkpoints to keep from removing')
-    train.add_argument('--resume', action='store_true',
-                       help='Resume training from the last checkpoint')
-    train.add_argument('--seed', type=int, default=1234,
-                       help='Seed for PyTorch random number generators')
-    train.add_argument('--amp', action='store_true',
-                       help='Enable AMP')
-    train.add_argument('--cuda', action='store_true',
-                       help='Run on GPU using CUDA')
-    train.add_argument('--cudnn-benchmark', action='store_true',
-                       help='Enable cudnn benchmark mode')
-    train.add_argument('--ema-decay', type=float, default=0,
-                       help='Discounting factor for training weights EMA')
-    train.add_argument('--grad-accumulation', type=int, default=1,
-                       help='Training steps to accumulate gradients for')
-    train.add_argument('--kl-loss-start-epoch', type=int, default=250,
-                       help='Start adding the hard attention loss term')
-    train.add_argument('--kl-loss-warmup-epochs', type=int, default=100,
-                       help='Gradually increase the hard attention loss term')
-    train.add_argument('--kl-loss-weight', type=float, default=1.0,
-                       help='Weight of the hard attention loss term')
-    train.add_argument('--benchmark-epochs-num', type=int, default=20,
-                       help='Number of epochs for calculating final stats')
-    train.add_argument('--validation-freq', type=int, default=1,
-                       help='Validate every N epochs to use less compute')
-
-    opt = parser.add_argument_group('optimization setup')
-    opt.add_argument('--optimizer', type=str, default='lamb',
-                     help='Optimization algorithm')
-    opt.add_argument('-lr', '--learning-rate', type=float, required=True,
-                     help='Learning rate')
-    opt.add_argument('--weight-decay', default=1e-6, type=float,
-                     help='Weight decay')
-    opt.add_argument('--grad-clip-thresh', default=1000.0, type=float,
-                     help='Clip threshold for gradients')
-    opt.add_argument('-bs', '--batch-size', type=int, required=True,
-                     help='Batch size per GPU')
-    opt.add_argument('--warmup-steps', type=int, default=1000,
-                     help='Number of steps for lr warmup')
-    opt.add_argument('--dur-predictor-loss-scale', type=float,
-                     default=1.0, help='Rescale duration predictor loss')
-    opt.add_argument('--pitch-predictor-loss-scale', type=float,
-                     default=1.0, help='Rescale pitch predictor loss')
-    opt.add_argument('--attn-loss-scale', type=float,
-                     default=1.0, help='Rescale alignment loss')
-
-    data = parser.add_argument_group('dataset parameters')
-    data.add_argument('--training-files', type=str, nargs='*', required=True,
-                      help='Paths to training filelists.')
-    data.add_argument('--validation-files', type=str, nargs='*',
-                      required=True, help='Paths to validation filelists')
-    data.add_argument('--text-cleaners', nargs='*',
-                      default=['english_cleaners'], type=str,
-                      help='Type of text cleaners for input text')
-    data.add_argument('--symbol-set', type=str, default='english_basic',
-                      help='Define symbol set for input text')
-    data.add_argument('--p-arpabet', type=float, default=0.0,
-                      help='Probability of using arpabets instead of graphemes '
-                           'for each word; set 0 for pure grapheme training')
-    data.add_argument('--heteronyms-path', type=str, default='cmudict/heteronyms',
-                      help='Path to the list of heteronyms')
-    data.add_argument('--cmudict-path', type=str, default='cmudict/cmudict-0.7b',
-                      help='Path to the pronouncing dictionary')
-    data.add_argument('--prepend-space-to-text', action='store_true',
-                      help='Capture leading silence with a space token')
-    data.add_argument('--append-space-to-text', action='store_true',
-                      help='Capture trailing silence with a space token')
-    data.add_argument('--num-workers', type=int, default=2,  # 6
-                      help='Subprocesses for train and val DataLoaders')
-    data.add_argument('--trainloader-repeats', type=int, default=100,
-                      help='Repeats the dataset to prolong epochs')
-
-    cond = parser.add_argument_group('data for conditioning')
-    cond.add_argument('--n-speakers', type=int, default=1,
-                      help='Number of speakers in the dataset. '
-                           'n_speakers > 1 enables speaker embeddings')
-    # ANT: added language
-    cond.add_argument('--n-languages', type=int, default=1,
-                      help='Number of languages in the dataset. '
-                           'n_languages > 1 enables language embeddings')
-
-    cond.add_argument('--load-pitch-from-disk', action='store_true',
-                      help='Use pitch cached on disk with prepare_dataset.py')
-    cond.add_argument('--pitch-online-method', default='pyin',
-                      choices=['pyin'],
-                      help='Calculate pitch on the fly during training')
-    cond.add_argument('--pitch-online-dir', type=str, default=None,
-                      help='A directory for storing pitch calculated on-line')
-    cond.add_argument('--pitch-mean', type=float, default=125.626816,  # default=214.72203,
-                      help='Normalization value for pitch')
-    cond.add_argument('--pitch-std', type=float, default=37.52,  # default=65.72038,
-                      help='Normalization value for pitch')
-    cond.add_argument('--load-mel-from-disk', action='store_true',
-                      help='Use mel-spectrograms cache on the disk')  # XXX
-
-    audio = parser.add_argument_group('audio parameters')
-    audio.add_argument('--max-wav-value', default=32768.0, type=float,
-                       help='Maximum audiowave value')
-    audio.add_argument('--sampling-rate', default=22050, type=int,
-                       help='Sampling rate')
-    audio.add_argument('--filter-length', default=1024, type=int,
-                       help='Filter length')
-    audio.add_argument('--hop-length', default=256, type=int,
-                       help='Hop (stride) length')
-    audio.add_argument('--win-length', default=1024, type=int,
-                       help='Window length')
-    audio.add_argument('--mel-fmin', default=0.0, type=float,
-                       help='Minimum mel frequency')
-    audio.add_argument('--mel-fmax', default=8000.0, type=float,
-                       help='Maximum mel frequency')
-
-    dist = parser.add_argument_group('distributed setup')
-    dist.add_argument('--local_rank', type=int, default=os.getenv('LOCAL_RANK', 0),
-                      help='Rank of the process for multiproc; do not set manually')
-    dist.add_argument('--world_size', type=int, default=os.getenv('WORLD_SIZE', 1),
-                      help='Number of processes for multiproc; do not set manually')
-    return parser
-
-
-def reduce_tensor(tensor, num_gpus):
-    rt = tensor.clone()
-    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
-    return rt.true_divide(num_gpus)
-
-
-def init_distributed(args, world_size, rank):
-    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
-    print("Initializing distributed training")
-
-    # Set cuda device so everything is done on the right GPU.
-    torch.cuda.set_device(rank % torch.cuda.device_count())
-
-    # Initialize distributed communication
-    dist.init_process_group(backend=('nccl' if args.cuda else 'gloo'),
-                            init_method='env://')
-    print("Done initializing distributed training")
-
-
-def validate(model, epoch, total_iter, criterion, val_loader, distributed_run,
-             batch_to_gpu, local_rank, ema=False):
-    was_training = model.training
-    model.eval()
-
-    tik = time.perf_counter()
-    with torch.no_grad():
-        val_meta = defaultdict(float)
-        val_num_frames = 0
-        for i, batch in enumerate(val_loader):
-            x, y, num_frames = batch_to_gpu(batch)
-            y_pred = model(x)
-            loss, meta = criterion(y_pred, y, is_training=False, meta_agg='sum')
-
-            if distributed_run:
-                for k, v in meta.items():
-                    val_meta[k] += reduce_tensor(v, 1)
-                val_num_frames += reduce_tensor(num_frames.data, 1).item()
-            else:
-                for k, v in meta.items():
-                    val_meta[k] += v
-                val_num_frames += num_frames.item()
-
-            # NOTE: ugly patch to visualize the first utterance of the validation corpus.
-            # The goal is to determine if the training is progressing properly
-            if (i == 0) and (local_rank == 0) and (not ema):
-                # Plot some debug information
-                fig, axs = plt.subplots(2, 2, figsize=(21, 14))
-
-                # - Mel-spectrogram
-                pred_mel = y_pred[0][0, :, :].cpu().detach().numpy().astype(np.float32).T
-                orig_mel = y[0][0, :, :].cpu().detach().numpy().astype(np.float32)
-                axs[0, 0].imshow(orig_mel, aspect='auto', origin='lower', interpolation='nearest')
-                axs[1, 0].imshow(pred_mel, aspect='auto', origin='lower', interpolation='nearest')
-
-                # Prosody
-                f0_pred = y_pred[4][0, :].cpu().detach().numpy().astype(np.float32)
-                f0_ori = y_pred[5][0, :].cpu().detach().numpy().astype(np.float32)
-                axs[1, 1].plot(f0_ori)
-                axs[1, 1].plot(f0_pred)
-
-                # # Duration
-                # att_pred = y_pred[2][0, :].cpu().detach().numpy().astype(np.float32)
-                # att_ori = x[7][0, :].cpu().detach().numpy().astype(np.float32)
-                # axs[0, 1].imshow(att_ori, aspect='auto', origin='lower', interpolation='nearest')
-
-                if not os.path.exists("debug_epoch/"):
-                    os.makedirs("debug_epoch/")
-
-                fig.savefig(f'debug_epoch/{epoch:06d}.png', bbox_inches='tight')
-
-        val_meta = {k: v / len(val_loader.dataset) for k, v in val_meta.items()}
-
-    val_meta['took'] = time.perf_counter() - tik
-
-    log((epoch,) if epoch is not None else (), tb_total_steps=total_iter,
-        subset='val_ema' if ema else 'val',
-        data=OrderedDict([
-            ('loss', val_meta['loss'].item()),
-            ('mel_loss', val_meta['mel_loss'].item()),
-            ('frames/s', val_num_frames / val_meta['took']),
-            ('took', val_meta['took'])]),
-        )
-
-    if was_training:
-        model.train()
-    return val_meta
-
-
-def adjust_learning_rate(total_iter, opt, learning_rate, warmup_iters=None):
-    if warmup_iters == 0:
-        scale = 1.0
-    elif total_iter > warmup_iters:
-        scale = 1. / (total_iter ** 0.5)
-    else:
-        scale = total_iter / (warmup_iters ** 1.5)
-
-    for param_group in opt.param_groups:
-        param_group['lr'] = learning_rate * scale
-
-
-def apply_ema_decay(model, ema_model, decay):
-    if not decay:
-        return
-    st = model.state_dict()
-    add_module = hasattr(model, 'module') and not hasattr(ema_model, 'module')
-    for k, v in ema_model.state_dict().items():
-        if add_module and not k.startswith('module.'):
-            k = 'module.' + k
-        v.copy_(decay * v + (1 - decay) * st[k])
-
-
-def init_multi_tensor_ema(model, ema_model):
-    model_weights = list(model.state_dict().values())
-    ema_model_weights = list(ema_model.state_dict().values())
-    ema_overflow_buf = torch.cuda.IntTensor([0])
-    return model_weights, ema_model_weights, ema_overflow_buf
-
-
-def apply_multi_tensor_ema(decay, model_weights, ema_weights, overflow_buf):
-    amp_C.multi_tensor_axpby(
-        65536, overflow_buf, [ema_weights, model_weights, ema_weights],
-        decay, 1-decay, -1)
-
-
-def main():
-    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
-                                     allow_abbrev=False)
-    parser = parse_args(parser)
-    args, _ = parser.parse_known_args()
-
-    if args.p_arpabet > 0.0:
-        cmudict.initialize(args.cmudict_path, args.heteronyms_path)
-
-    distributed_run = args.world_size > 1
-
-    torch.manual_seed(args.seed + args.local_rank)
-    np.random.seed(args.seed + args.local_rank)
-
-    if args.local_rank == 0:
-        if not os.path.exists(args.output):
-            os.makedirs(args.output)
-
-    log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
-    tb_subsets = ['train', 'val']
-    if args.ema_decay > 0.0:
-        tb_subsets.append('val_ema')
-
-    logger.init(log_fpath, args.output, enabled=(args.local_rank == 0),
-                tb_subsets=tb_subsets)
-    logger.parameters(vars(args), tb_subset='train')
-
-    parser = models.parse_model_args('FastPitch', parser)
-    args, unk_args = parser.parse_known_args()
-    if len(unk_args) > 0:
-        raise ValueError(f'Invalid options {unk_args}')
-
-    torch.backends.cudnn.benchmark = args.cudnn_benchmark
-
-    if distributed_run:
-        init_distributed(args, args.world_size, args.local_rank)
-    else:
-        if args.trainloader_repeats > 1:
-            print('WARNING: Disabled --trainloader-repeats, supported only for'
-                  ' multi-GPU data loading.')
-            args.trainloader_repeats = 1
-
-    device = torch.device('cuda' if args.cuda else 'cpu')
-    model_config = models.get_model_config('FastPitch', args)
-    model = models.get_model('FastPitch', model_config, device)
-
-    attention_kl_loss = AttentionBinarizationLoss()
-
-    # Store pitch mean/std as params to translate from Hz during inference
-    model.pitch_mean[0] = args.pitch_mean
-    model.pitch_std[0] = args.pitch_std
-
-    kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
-              weight_decay=args.weight_decay)
-    if args.optimizer == 'adam':
-        optimizer = FusedAdam(model.parameters(), **kw)
-        # optimizer = torch.optim.Adam(model.parameters(), **kw)
-    elif args.optimizer == 'lamb':
-
-        optimizer = FusedLAMB(model.parameters(), **kw)
-        # optimizer = torch.optim.Adam(model.parameters(), **kw)
-    else:
-        raise ValueError
-
-    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
-
-    if args.ema_decay > 0:
-        ema_model = copy.deepcopy(model)
-    else:
-        ema_model = None
-
-    if distributed_run:
-        model = DistributedDataParallel(
-            model, device_ids=[args.local_rank], output_device=args.local_rank,
-            find_unused_parameters=True)
-
-    train_state = {'epoch': 1, 'total_iter': 1}
-    checkpointer = Checkpointer(args.output, args.keep_milestones)
-
-    checkpointer.maybe_load(model, optimizer, scaler, train_state, args,
-                            ema_model)
-
-    start_epoch = train_state['epoch']
-    total_iter = train_state['total_iter']
-
-    criterion = FastPitchLoss(
-        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
-        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale,
-        attn_loss_scale=args.attn_loss_scale)
-
-    collate_fn = TTSCollate()
-
-    if args.local_rank == 0:
-        prepare_tmp(args.pitch_online_dir)
-
-    trainset = TTSDataset(audiopaths_and_text=args.training_files, **vars(args))
-    valset = TTSDataset(audiopaths_and_text=args.validation_files, **vars(args))
-
-    if distributed_run:
-        train_sampler = RepeatedDistributedSampler(args.trainloader_repeats,
-                                                   trainset, drop_last=True)
-        val_sampler = DistributedSampler(valset)
-        shuffle = False
-    else:
-        train_sampler, val_sampler, shuffle = None, None, False  # was True
-
-    # 4 workers are optimal on DGX-1 (from epoch 2 onwards)
-    kw = {'num_workers': args.num_workers, 'batch_size': args.batch_size,
-          'collate_fn': collate_fn}
-    train_loader = RepeatedDataLoader(args.trainloader_repeats, trainset,
-                                      shuffle=shuffle, drop_last=True,
-                                      sampler=train_sampler, pin_memory=True,
-                                      persistent_workers=True, **kw)
-    val_loader = DataLoader(valset, shuffle=False, sampler=val_sampler,
-                            pin_memory=False, **kw)
-    if args.ema_decay:
-        mt_ema_params = init_multi_tensor_ema(model, ema_model)
-
-    model.train()
-    bmark_stats = BenchmarkStats()
-
-    torch.cuda.synchronize()
-    for epoch in range(start_epoch, args.epochs + 1):
-        epoch_start_time = time.perf_counter()
-
-        epoch_loss = 0.0
-        epoch_mel_loss = 0.0
-        epoch_num_frames = 0
-        epoch_frames_per_sec = 0.0
-
-        if distributed_run:
-            train_loader.sampler.set_epoch(epoch)
-
-        iter_loss = 0
-        iter_num_frames = 0
-        iter_meta = {}
-        iter_start_time = time.perf_counter()
-
-        epoch_iter = 1
-        for batch, accum_step in zip(train_loader,
-                                     cycle(range(1, args.grad_accumulation + 1))):
-            if accum_step == 1:
-                adjust_learning_rate(total_iter, optimizer, args.learning_rate,
-                                     args.warmup_steps)
-
-                model.zero_grad(set_to_none=True)
-
-            x, y, num_frames = batch_to_gpu(batch)
-
-            with torch.cuda.amp.autocast(enabled=args.amp):
-                y_pred = model(x)
-                loss, meta = criterion(y_pred, y)
-
-                if (args.kl_loss_start_epoch is not None
-                        and epoch >= args.kl_loss_start_epoch):
-
-                    if args.kl_loss_start_epoch == epoch and epoch_iter == 1:
-                        print('Begin hard_attn loss')
-
-                    _, _, _, _, _, _, _, _, attn_soft, attn_hard, _, _ = y_pred
-                    binarization_loss = attention_kl_loss(attn_hard, attn_soft)
-                    kl_weight = min((epoch - args.kl_loss_start_epoch) / args.kl_loss_warmup_epochs, 1.0) * args.kl_loss_weight
-                    meta['kl_loss'] = binarization_loss.clone().detach() * kl_weight
-                    loss += kl_weight * binarization_loss
-
-                else:
-                    meta['kl_loss'] = torch.zeros_like(loss)
-                    kl_weight = 0
-                    binarization_loss = 0
-
-                loss /= args.grad_accumulation
-                meta = {k: v / args.grad_accumulation
-                        for k, v in meta.items()}
-
-            if args.amp:
-                scaler.scale(loss).backward()
-            else:
-                loss.backward()
-
-            if distributed_run:
-                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
-                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
-                meta = {k: reduce_tensor(v, args.world_size) for k, v in meta.items()}
-            else:
-                reduced_loss = loss.item()
-                reduced_num_frames = num_frames.item()
-            if np.isnan(reduced_loss):
-                raise Exception("loss is NaN")
-
-            iter_loss += reduced_loss
-            iter_num_frames += reduced_num_frames
-            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}
-            if accum_step % args.grad_accumulation == 0:
-
-                logger.log_grads_tb(total_iter, model)
-                if args.amp:
-                    scaler.unscale_(optimizer)
-                    torch.nn.utils.clip_grad_norm_(
-                        model.parameters(), args.grad_clip_thresh)
-                    scaler.step(optimizer)
-                    scaler.update()
-                else:
-                    torch.nn.utils.clip_grad_norm_(
-                        model.parameters(), args.grad_clip_thresh)
-                    optimizer.step()
-
-                if args.ema_decay > 0.0:
-                    apply_multi_tensor_ema(args.ema_decay, *mt_ema_params)
-
-                iter_mel_loss = iter_meta['mel_loss'].item()
-                iter_kl_loss = iter_meta['kl_loss'].item()
-                iter_time = time.perf_counter() - iter_start_time
-                epoch_frames_per_sec += iter_num_frames / iter_time
-                epoch_loss += iter_loss
-                epoch_num_frames += iter_num_frames
-                epoch_mel_loss += iter_mel_loss
-
-                num_iters = len(train_loader) // args.grad_accumulation
-                log((epoch, epoch_iter, num_iters), tb_total_steps=total_iter,
-                    subset='train', data=OrderedDict([
-                        ('loss', iter_loss),
-                        ('mel_loss', iter_mel_loss),
-                        ('kl_loss', iter_kl_loss),
-                        ('kl_weight', kl_weight),
-                        ('frames/s', iter_num_frames / iter_time),
-                        ('took', iter_time),
-                        ('lrate', optimizer.param_groups[0]['lr'])]),
-                    )
-
-                iter_loss = 0
-                iter_num_frames = 0
-                iter_meta = {}
-                iter_start_time = time.perf_counter()
-
-                if epoch_iter == num_iters:
-                    break
-                epoch_iter += 1
-                total_iter += 1
-
-        # Finished epoch
-        epoch_loss /= epoch_iter
-        epoch_mel_loss /= epoch_iter
-        epoch_time = time.perf_counter() - epoch_start_time
-        log((epoch,), tb_total_steps=None, subset='train_avg',
-            data=OrderedDict([
-                ('loss', epoch_loss),
-                ('mel_loss', epoch_mel_loss),
-                ('frames/s', epoch_num_frames / epoch_time),
-                ('took', epoch_time)]),
-            )
-        bmark_stats.update(epoch_num_frames, epoch_loss, epoch_mel_loss,
-                           epoch_time)
-
-        if epoch % args.validation_freq == 0:
-            validate(model, epoch, total_iter, criterion, val_loader,
-                     distributed_run, batch_to_gpu, ema=False, local_rank=args.local_rank)
-
-            if args.ema_decay > 0:
-                validate(ema_model, epoch, total_iter, criterion, val_loader,
-                         distributed_run, batch_to_gpu, args.local_rank, ema=True)
-
-        # save before making sched.step() for proper loading of LR
-        checkpointer.maybe_save(args, model, ema_model, optimizer, scaler,
-                                epoch, total_iter, model_config)
-        logger.flush()
-
-    # Finished training
-    if len(bmark_stats) > 0:
-        log((), tb_total_steps=None, subset='train_avg',
-            data=bmark_stats.get(args.benchmark_epochs_num))
-
-    validate(model, None, total_iter, criterion, val_loader, distributed_run,
-             batch_to_gpu, args.local_rank)
-
-
-if __name__ == '__main__':
-    main()