katrihiovain committed
Commit 95b5cf1
1 Parent(s): 2b63853

removed unnecessary files and updated app.py
README1.md DELETED
@@ -1,15 +0,0 @@
- # FastPitchMulti
- Experimental multi-lingual FastPitch
-
- What's done:
- - [x] Conditioning on language and speaker labels
- - [x] Dataset and preprocessing of Sámi data
- - [x] Combined character set for the Sámi languages
- - [x] Train a model on Sámi languages
- - [x] Selecting Estonian data
- - [x] Processing Estonian data
- - [ ] Train a model on Sámi x 3, Finnish, Estonian
-
- Ideas:
- - Move the language embedding to the very beginning of the encoder
-
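Note: the deleted README's last idea, moving the language embedding to the very beginning of the encoder, is not implemented in this commit. A minimal sketch of what that kind of conditioning could look like (hypothetical module and names, assuming a standard PyTorch FastPitch-style encoder input; not code from this repository):

```python
import torch
import torch.nn as nn

class LangConditionedEmbedding(nn.Module):
    """Sketch: add a per-utterance language embedding to the symbol embeddings
    before they enter the encoder (names and shapes are illustrative only)."""
    def __init__(self, n_symbols, n_langs, d_model):
        super().__init__()
        self.symbol_emb = nn.Embedding(n_symbols, d_model)
        self.lang_emb = nn.Embedding(n_langs, d_model)

    def forward(self, text_ids, lang_ids):
        # text_ids: (B, T) symbol indices; lang_ids: (B,) language indices
        x = self.symbol_emb(text_ids)                 # (B, T, d_model)
        x = x + self.lang_emb(lang_ids).unsqueeze(1)  # broadcast over time axis
        return x
```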
app.py CHANGED
@@ -11,7 +11,7 @@ speakers={"aj0": 0,
  "aj1": 1,
  "am": 2,
  "bi": 3,
- "kd": 4,
+ #"kd": 4,
  "ln": 5,
  "lo": 6,
  "ms": 7,
common/text/symbols_2.py DELETED
@@ -1,64 +0,0 @@
1
- """ from https://github.com/keithito/tacotron """
2
-
3
- '''
4
- Defines the set of symbols used in text input to the model.
5
-
6
- The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
7
- from .cmudict import valid_symbols
8
-
9
-
10
- # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
11
- _arpabet = ['@' + s for s in valid_symbols]
12
-
13
-
14
- def get_symbols(symbol_set='english_basic'):
15
- if symbol_set == 'english_basic':
16
- _pad = '_'
17
- _punctuation = '!\'(),.:;? '
18
- _special = '-'
19
- _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
20
- symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
21
- elif symbol_set == 'english_basic_lowercase':
22
- _pad = '_'
23
- _punctuation = '!\'"(),.:;? '
24
- _special = '-'
25
- _letters = 'abcdefghijklmnopqrstuvwxyz'
26
- symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
27
- elif symbol_set == 'english_expanded':
28
- _punctuation = '!\'",.:;? '
29
- _math = '#%&*+-/[]()'
30
- _special = '_@©°½—₩€$'
31
- _accented = 'áçéêëñöøćž'
32
- _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
33
- symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
34
- elif symbol_set == 'smj_expanded':
35
- _punctuation = '!\'",.:;?- '
36
- _math = '#%&*+-/[]()'
37
- _special = '_@©°½—₩€$'
38
- # _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
39
- _accented = 'áçéêëñöø' #also north sámi letters...
40
- # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
41
- _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
42
- # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
43
- symbols = list(_punctuation + _letters) + _arpabet
44
- elif symbol_set == 'sme_expanded':
45
- _punctuation = '!\'",.:;?- '
46
- _math = '#%&*+-/[]()'
47
- _special = '_@©°½—₩€$'
48
- _accented = 'áçéêëńñöøćčžđšŧ' #also north sámi letters...
49
- # _accented = 'áçéêëñöø' #also north sámi letters...
50
- # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
51
- _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋoøöpqrstuvwxyz'
52
- # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
53
- symbols = list(_punctuation + _letters) + _arpabet
54
- else:
55
- raise Exception("{} symbol set does not exist".format(symbol_set))
56
-
57
- return symbols
58
-
59
-
60
- def get_pad_idx(symbol_set='english_basic'):
61
- if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded', 'sme_expanded'}:
62
- return 0
63
- else:
64
- raise Exception("{} symbol set not used yet".format(symbol_set))
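For context, get_symbols() from the deleted symbol-set modules above is typically consumed by building a symbol-to-index lookup for text encoding. A hedged usage sketch (illustrative only, not code from this repository; the greeting is Northern Sámi for "good day"):

```python
# Hypothetical usage of get_symbols(): map input characters to integer IDs.
symbols = get_symbols('sme_expanded')
symbol_to_id = {s: i for i, s in enumerate(symbols)}

text = "Buorre beaivi!"
ids = [symbol_to_id[ch] for ch in text if ch in symbol_to_id]
print(ids)  # sequence of integer IDs fed to the model's embedding layer
```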
common/text/symbols_BACKUP_MARCH_2024.py DELETED
@@ -1,64 +0,0 @@
1
- """ from https://github.com/keithito/tacotron """
2
-
3
- '''
4
- Defines the set of symbols used in text input to the model.
5
-
6
- The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
7
- from .cmudict import valid_symbols
8
-
9
-
10
- # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
11
- _arpabet = ['@' + s for s in valid_symbols]
12
-
13
-
14
- def get_symbols(symbol_set='english_basic'):
15
- if symbol_set == 'english_basic':
16
- _pad = '_'
17
- _punctuation = '!\'(),.:;? '
18
- _special = '-'
19
- _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
20
- symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
21
- elif symbol_set == 'english_basic_lowercase':
22
- _pad = '_'
23
- _punctuation = '!\'"(),.:;? '
24
- _special = '-'
25
- _letters = 'abcdefghijklmnopqrstuvwxyz'
26
- symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
27
- elif symbol_set == 'english_expanded':
28
- _punctuation = '!\'",.:;? '
29
- _math = '#%&*+-/[]()'
30
- _special = '_@©°½—₩€$'
31
- _accented = 'áçéêëñöøćž'
32
- _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
33
- symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
34
- elif symbol_set == 'smj_expanded':
35
- _punctuation = '!\'",.:;?- '
36
- _math = '#%&*+-/[]()'
37
- _special = '_@©°½—₩€$'
38
- # _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
39
- _accented = 'áçéêëñöø' #also north sámi letters...
40
- # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
41
- _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTŦUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz' ########################## Ŧ ########################
42
- # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
43
- symbols = list(_punctuation + _letters) + _arpabet
44
- elif symbol_set == 'sme_expanded':
45
- _punctuation = '!\'",.:;?- '
46
- _math = '#%&*+-/[]()'
47
- _special = '_@©°½—₩€$'
48
- _accented = 'áçéêëńñöøćčžđšŧ' #also north sámi letters...
49
- # _accented = 'áçéêëñöø' #also north sámi letters...
50
- # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
51
- _letters = 'AÁÆÅÄBCČDĐEFGHIJKLMNŊOØÖPQRSŠTŦUVWXYZŽaáæåäbcčdđefghijklmnŋoøöpqrsštŧuvwxyzž'
52
- # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
53
- symbols = list(_punctuation + _letters) + _arpabet
54
- else:
55
- raise Exception("{} symbol set does not exist".format(symbol_set))
56
-
57
- return symbols
58
-
59
-
60
- def get_pad_idx(symbol_set='english_basic'):
61
- if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded', 'sme_expanded'}:
62
- return 0
63
- else:
64
- raise Exception("{} symbol set not used yet".format(symbol_set))
common/text/symbols_ORIGINAL.py DELETED
@@ -1,65 +0,0 @@
1
- """ from https://github.com/keithito/tacotron """
2
-
3
- '''
4
- Defines the set of symbols used in text input to the model.
5
-
6
- The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
7
- from .cmudict import valid_symbols
8
-
9
-
10
- # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
11
- _arpabet = ['@' + s for s in valid_symbols]
12
-
13
-
14
- def get_symbols(symbol_set='english_basic'):
15
- if symbol_set == 'english_basic':
16
- _pad = '_'
17
- _punctuation = '!\'(),.:;? '
18
- _special = '-'
19
- _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
20
- symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
21
- elif symbol_set == 'english_basic_lowercase':
22
- _pad = '_'
23
- _punctuation = '!\'"(),.:;? '
24
- _special = '-'
25
- _letters = 'abcdefghijklmnopqrstuvwxyz'
26
- symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
27
- elif symbol_set == 'english_expanded':
28
- _punctuation = '!\'",.:;? '
29
- _math = '#%&*+-/[]()'
30
- _special = '_@©°½—₩€$'
31
- _accented = 'áçéêëñöøćž'
32
- _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
33
- symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
34
- elif symbol_set == 'smj_expanded':
35
- _punctuation = '!\'",.:;?- '
36
- _math = '#%&*+-/[]()'
37
- _special = '_@©°½—₩€$'
38
- # _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
39
- _accented = 'áçéêëñöø' #also north sámi letters...
40
- # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
41
- _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTŦUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
42
- # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
43
- symbols = list(_punctuation + _letters) + _arpabet
44
- elif symbol_set == 'sme_expanded':
45
- _punctuation = '!\'",.:;?- '
46
- _math = '#%&*+-/[]()'
47
- _special = '_@©°½—₩€$'
48
- _accented = 'áçéêëńñöøćčžđšŧ' #also north sámi letters...
49
- # _accented = 'áçéêëñöø' #also north sámi letters...
50
- # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
51
- _letters = 'AÁÆÅÄBCČDĐEFGHIJKLMNŊOØÖPQRSŠTŦUVWXYZŽaáæåäbcčdđefghijklmnŋoøöpqrsštŧuvwxyzž'
52
- # _letters = 'AÁÆÅÄBCDĐEFGHIJKLMNŊOØÖPQRSŠTŦUVWXYZŽaáæåäbcčdđefghijklmnŋoøöpqrsštŧuvwxyzž'
53
- # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
54
- symbols = list(_punctuation + _letters) + _arpabet
55
- else:
56
- raise Exception("{} symbol set does not exist".format(symbol_set))
57
-
58
- return symbols
59
-
60
-
61
- def get_pad_idx(symbol_set='english_basic'):
62
- if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded', 'sme_expanded'}:
63
- return 0
64
- else:
65
- raise Exception("{} symbol set not used yet".format(symbol_set))
common/text/symbols_backup.py DELETED
@@ -1,54 +0,0 @@
1
- """ from https://github.com/keithito/tacotron """
2
-
3
- '''
4
- Defines the set of symbols used in text input to the model.
5
-
6
- The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
7
- from .cmudict import valid_symbols
8
-
9
-
10
- # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
11
- _arpabet = ['@' + s for s in valid_symbols]
12
-
13
-
14
- def get_symbols(symbol_set='english_basic'):
15
- if symbol_set == 'english_basic':
16
- _pad = '_'
17
- _punctuation = '!\'(),.:;? '
18
- _special = '-'
19
- _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
20
- symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
21
- elif symbol_set == 'english_basic_lowercase':
22
- _pad = '_'
23
- _punctuation = '!\'"(),.:;? '
24
- _special = '-'
25
- _letters = 'abcdefghijklmnopqrstuvwxyz'
26
- symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
27
- elif symbol_set == 'english_expanded':
28
- _punctuation = '!\'",.:;? '
29
- _math = '#%&*+-/[]()'
30
- _special = '_@©°½—₩€$'
31
- _accented = 'áçéêëñöøćž'
32
- _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
33
- symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
34
- elif symbol_set == 'smj_expanded':
35
- _punctuation = '!\'",.:;?- '
36
- _math = '#%&*+-/[]()'
37
- _special = '_@©°½—₩€$'
38
- # _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
39
- _accented = 'áçéêëñöø' #also north sámi letters...
40
- # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
41
- _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
42
- # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
43
- symbols = list(_punctuation + _letters) + _arpabet
44
- else:
45
- raise Exception("{} symbol set does not exist".format(symbol_set))
46
-
47
- return symbols
48
-
49
-
50
- def get_pad_idx(symbol_set='english_basic'):
51
- if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded'}:
52
- return 0
53
- else:
54
- raise Exception("{} symbol set not used yet".format(symbol_set))
common/text/symbols_sme.py DELETED
@@ -1,64 +0,0 @@
1
- """ from https://github.com/keithito/tacotron """
2
-
3
- '''
4
- Defines the set of symbols used in text input to the model.
5
-
6
- The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
7
- from .cmudict import valid_symbols
8
-
9
-
10
- # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
11
- _arpabet = ['@' + s for s in valid_symbols]
12
-
13
-
14
- def get_symbols(symbol_set='english_basic'):
15
- if symbol_set == 'english_basic':
16
- _pad = '_'
17
- _punctuation = '!\'(),.:;? '
18
- _special = '-'
19
- _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
20
- symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
21
- elif symbol_set == 'english_basic_lowercase':
22
- _pad = '_'
23
- _punctuation = '!\'"(),.:;? '
24
- _special = '-'
25
- _letters = 'abcdefghijklmnopqrstuvwxyz'
26
- symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
27
- elif symbol_set == 'english_expanded':
28
- _punctuation = '!\'",.:;? '
29
- _math = '#%&*+-/[]()'
30
- _special = '_@©°½—₩€$'
31
- _accented = 'áçéêëñöøćž'
32
- _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
33
- symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
34
- elif symbol_set == 'smj_expanded':
35
- _punctuation = '!\'",.:;?- '
36
- _math = '#%&*+-/[]()'
37
- _special = '_@©°½—₩€$'
38
- # _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
39
- _accented = 'áçéêëñöø' #also north sámi letters...
40
- # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
41
- _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
42
- # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
43
- symbols = list(_punctuation + _letters) + _arpabet
44
- elif symbol_set == 'sme_expanded':
45
- _punctuation = '!\'",.:;?- '
46
- _math = '#%&*+-/[]()'
47
- _special = '_@©°½—₩€$'
48
- _accented = 'áçéêëńñöøćčžđšŧ' #also north sámi letters...
49
- # _accented = 'áçéêëñöø' #also north sámi letters...
50
- # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
51
- _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋoøöpqrstuvwxyz'
52
- # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
53
- symbols = list(_punctuation + _letters) + _arpabet
54
- else:
55
- raise Exception("{} symbol set does not exist".format(symbol_set))
56
-
57
- return symbols
58
-
59
-
60
- def get_pad_idx(symbol_set='english_basic'):
61
- if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded', 'sme_expanded'}:
62
- return 0
63
- else:
64
- raise Exception("{} symbol set not used yet".format(symbol_set))
common/text/symbols_sme_1.py DELETED
@@ -1,64 +0,0 @@
1
- """ from https://github.com/keithito/tacotron """
2
-
3
- '''
4
- Defines the set of symbols used in text input to the model.
5
-
6
- The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
7
- from .cmudict import valid_symbols
8
-
9
-
10
- # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
11
- _arpabet = ['@' + s for s in valid_symbols]
12
-
13
-
14
- def get_symbols(symbol_set='english_basic'):
15
- if symbol_set == 'english_basic':
16
- _pad = '_'
17
- _punctuation = '!\'(),.:;? '
18
- _special = '-'
19
- _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
20
- symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
21
- elif symbol_set == 'english_basic_lowercase':
22
- _pad = '_'
23
- _punctuation = '!\'"(),.:;? '
24
- _special = '-'
25
- _letters = 'abcdefghijklmnopqrstuvwxyz'
26
- symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
27
- elif symbol_set == 'english_expanded':
28
- _punctuation = '!\'",.:;? '
29
- _math = '#%&*+-/[]()'
30
- _special = '_@©°½—₩€$'
31
- _accented = 'áçéêëñöøćž'
32
- _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
33
- symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
34
- elif symbol_set == 'smj_expanded':
35
- _punctuation = '!\'",.:;?- '
36
- _math = '#%&*+-/[]()'
37
- _special = '_@©°½—₩€$'
38
- # _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
39
- _accented = 'áçéêëñöø' #also north sámi letters...
40
- # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
41
- _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
42
- # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
43
- symbols = list(_punctuation + _letters) + _arpabet
44
- elif symbol_set == 'sme_expanded':
45
- _punctuation = '!\'",.:;?- '
46
- _math = '#%&*+-/[]()'
47
- _special = '_@©°½—₩€$'
48
- _accented = 'áçéêëńñöøćčžđšŧ' #also north sámi letters...
49
- # _accented = 'áçéêëñöø' #also north sámi letters...
50
- # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
51
- _letters = 'AÁÆÅÄBCDĐEFGHIJKLMNŊOØÖPQRSŠTUVWXYZŽaáæåäbcdđefghijklmnŋoøöpqrsštŧuvwxyzž'
52
- # symbols = list(_punctuation + _math + _special + _accented + _letters) #+ _arpabet
53
- symbols = list(_punctuation + _letters) + _arpabet
54
- else:
55
- raise Exception("{} symbol set does not exist".format(symbol_set))
56
-
57
- return symbols
58
-
59
-
60
- def get_pad_idx(symbol_set='english_basic'):
61
- if symbol_set in {'english_basic', 'english_basic_lowercase', 'smj_expanded', 'sme_expanded'}:
62
- return 0
63
- else:
64
- raise Exception("{} symbol set not used yet".format(symbol_set))
common/text/symbols_smj.py DELETED
@@ -1,48 +0,0 @@
1
- """ from https://github.com/keithito/tacotron """
2
-
3
- '''
4
- Defines the set of symbols used in text input to the model.
5
-
6
- The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
7
- from .cmudict import valid_symbols
8
-
9
-
10
- # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
11
- _arpabet = ['@' + s for s in valid_symbols]
12
-
13
-
14
- def get_symbols(symbol_set='smj_basic'):
15
- if symbol_set == 'smj_basic':
16
- _pad = '_'
17
- _punctuation = '!\'(),.:;? '
18
- _special = '-'
19
- # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
20
- _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
21
- # LETTERS = 'AaÁáBbCcDdEeFfGgHhIiJjKkLlMmNnŊŋOoPpRrSsTtUuVvZzŃńÑñÆæØøÅåÄäÖö'
22
- symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
23
- elif symbol_set == 'smj_basic_lowercase':
24
- _pad = '_'
25
- _punctuation = '!\'"(),.:;? '
26
- _special = '-'
27
- # _letters = 'abcdefghijklmnopqrstuvwxyz'
28
- _letters = 'aáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
29
- symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
30
- elif symbol_set == 'smj_expanded':
31
- _punctuation = '!\'",.:;? '
32
- _math = '#%&*+-/[]()'
33
- _special = '_@©°½—₩€$'
34
- _accented = 'áçéêëñöøćžđšŧ' #also north sámi letters...
35
- # _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
36
- _letters = 'AÁÆÅÄBCDEFGHIJKLMNŊŃÑOØÖPQRSTUVWXYZaáæåäbcdefghijklmnŋńñoøöpqrstuvwxyz'
37
- symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
38
- else:
39
- raise Exception("{} symbol set does not exist".format(symbol_set))
40
-
41
- return symbols
42
-
43
-
44
- def get_pad_idx(symbol_set='smj_basic'):
45
- if symbol_set in {'smj_basic', 'smj_basic_lowercase'}:
46
- return 0
47
- else:
48
- raise Exception("{} symbol set not used yet".format(symbol_set))
common/utils_hfg.py DELETED
@@ -1,14 +0,0 @@
-
- ##############################################################################
- # Foreing utils.py from HiFi-GAN
- ##############################################################################
-
-
- def init_weights(m, mean=0.0, std=0.01):
-     classname = m.__class__.__name__
-     if classname.find("Conv") != -1:
-         m.weight.data.normal_(mean, std)
-
-
- def get_padding(kernel_size, dilation=1):
-     return int((kernel_size*dilation - dilation)/2)
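As a quick sanity check of the deleted get_padding helper: for a stride-1 convolution, padding = (kernel_size*dilation - dilation)//2 keeps the time dimension unchanged. A small hedged example (not part of the commit):

```python
import torch
import torch.nn as nn

# get_padding(3, 1) == 1, so a stride-1 Conv1d with kernel 3 preserves length.
conv = nn.Conv1d(80, 80, kernel_size=3, padding=get_padding(3, 1))
x = torch.randn(1, 80, 100)        # (batch, channels, time)
assert conv(x).shape == x.shape    # time dimension stays 100
```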
common/utils_ok.py DELETED
@@ -1,291 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- # MIT License
16
- #
17
- # Copyright (c) 2020 Jungil Kong
18
- #
19
- # Permission is hereby granted, free of charge, to any person obtaining a copy
20
- # of this software and associated documentation files (the "Software"), to deal
21
- # in the Software without restriction, including without limitation the rights
22
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
23
- # copies of the Software, and to permit persons to whom the Software is
24
- # furnished to do so, subject to the following conditions:
25
- #
26
- # The above copyright notice and this permission notice shall be included in all
27
- # copies or substantial portions of the Software.
28
- #
29
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
30
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
31
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
32
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
33
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
34
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
35
- # SOFTWARE.
36
-
37
- # The following functions/classes were based on code from https://github.com/jik876/hifi-gan:
38
- # init_weights, get_padding, AttrDict
39
-
40
- import ctypes
41
- import glob
42
- import os
43
- import re
44
- import shutil
45
- import warnings
46
- from collections import defaultdict, OrderedDict
47
- from pathlib import Path
48
- from typing import Optional
49
-
50
- import librosa
51
- import numpy as np
52
-
53
- import torch
54
- import torch.distributed as dist
55
- from scipy.io.wavfile import read
56
-
57
-
58
- def mask_from_lens(lens, max_len: Optional[int] = None):
59
- if max_len is None:
60
- max_len = lens.max()
61
- ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
62
- mask = torch.lt(ids, lens.unsqueeze(1))
63
- return mask
64
-
65
-
66
- def load_wav(full_path, torch_tensor=False):
67
- import soundfile # flac
68
- data, sampling_rate = soundfile.read(full_path, dtype='int16')
69
- if torch_tensor:
70
- return torch.FloatTensor(data.astype(np.float32)), sampling_rate
71
- else:
72
- return data, sampling_rate
73
-
74
-
75
- def load_wav_to_torch(full_path, force_sampling_rate=None):
76
- if force_sampling_rate is not None:
77
- data, sampling_rate = librosa.load(full_path, sr=force_sampling_rate)
78
- else:
79
- sampling_rate, data = read(full_path)
80
-
81
- return torch.FloatTensor(data.astype(np.float32)), sampling_rate
82
-
83
-
84
- def load_filepaths_and_text(dataset_path, fnames, has_speakers=False, split="|"):
85
- def split_line(root, line):
86
- parts = line.strip().split(split)
87
- if has_speakers:
88
- paths, non_paths = parts[:-2], parts[-2:]
89
- else:
90
- paths, non_paths = parts[:-1], parts[-1:]
91
- return tuple(str(Path(root, p)) for p in paths) + tuple(non_paths)
92
-
93
- fpaths_and_text = []
94
- for fname in fnames:
95
- with open(fname, encoding='utf-8') as f:
96
- fpaths_and_text += [split_line(dataset_path, line) for line in f]
97
- return fpaths_and_text
98
-
99
-
100
- def to_gpu(x):
101
- x = x.contiguous()
102
- return x.cuda(non_blocking=True) if torch.cuda.is_available() else x
103
-
104
-
105
- def l2_promote():
106
- _libcudart = ctypes.CDLL('libcudart.so')
107
- # Set device limit on the current device
108
- # cudaLimitMaxL2FetchGranularity = 0x05
109
- pValue = ctypes.cast((ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int))
110
- _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
111
- _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
112
- assert pValue.contents.value == 128
113
-
114
-
115
- def prepare_tmp(path):
116
- if path is None:
117
- return
118
- p = Path(path)
119
- if p.is_dir():
120
- warnings.warn(f'{p} exists. Removing...')
121
- shutil.rmtree(p, ignore_errors=True)
122
- p.mkdir(parents=False, exist_ok=False)
123
-
124
-
125
- def print_once(*msg):
126
- if not dist.is_initialized() or dist.get_rank() == 0:
127
- print(*msg)
128
-
129
-
130
- def init_weights(m, mean=0.0, std=0.01):
131
- classname = m.__class__.__name__
132
- if classname.find("Conv") != -1:
133
- m.weight.data.normal_(mean, std)
134
-
135
-
136
- def get_padding(kernel_size, dilation=1):
137
- return int((kernel_size*dilation - dilation)/2)
138
-
139
-
140
- class AttrDict(dict):
141
- def __init__(self, *args, **kwargs):
142
- super(AttrDict, self).__init__(*args, **kwargs)
143
- self.__dict__ = self
144
-
145
-
146
- class DefaultAttrDict(defaultdict):
147
- def __init__(self, *args, **kwargs):
148
- super(DefaultAttrDict, self).__init__(*args, **kwargs)
149
- self.__dict__ = self
150
-
151
- def __getattr__(self, item):
152
- return self[item]
153
-
154
-
155
- class BenchmarkStats:
156
- """ Tracks statistics used for benchmarking. """
157
- def __init__(self):
158
- self.num_frames = []
159
- self.losses = []
160
- self.mel_losses = []
161
- self.took = []
162
-
163
- def update(self, num_frames, losses, mel_losses, took):
164
- self.num_frames.append(num_frames)
165
- self.losses.append(losses)
166
- self.mel_losses.append(mel_losses)
167
- self.took.append(took)
168
-
169
- def get(self, n_epochs):
170
- frames_s = sum(self.num_frames[-n_epochs:]) / sum(self.took[-n_epochs:])
171
- return {'frames/s': frames_s,
172
- 'loss': np.mean(self.losses[-n_epochs:]),
173
- 'mel_loss': np.mean(self.mel_losses[-n_epochs:]),
174
- 'took': np.mean(self.took[-n_epochs:]),
175
- 'benchmark_epochs_num': n_epochs}
176
-
177
- def __len__(self):
178
- return len(self.losses)
179
-
180
-
181
- class Checkpointer:
182
-
183
- def __init__(self, save_dir, keep_milestones=[]):
184
- self.save_dir = save_dir
185
- self.keep_milestones = keep_milestones
186
-
187
- find = lambda name: [
188
- (int(re.search("_(\d+).pt", fn).group(1)), fn)
189
- for fn in glob.glob(f"{save_dir}/{name}_checkpoint_*.pt")]
190
-
191
- tracked = sorted(find("FastPitch"), key=lambda t: t[0])
192
- self.tracked = OrderedDict(tracked)
193
-
194
- def last_checkpoint(self, output):
195
-
196
- def corrupted(fpath):
197
- try:
198
- torch.load(fpath, map_location="cpu")
199
- return False
200
- except:
201
- warnings.warn(f"Cannot load {fpath}")
202
- return True
203
-
204
- saved = sorted(
205
- glob.glob(f"{output}/FastPitch_checkpoint_*.pt"),
206
- key=lambda f: int(re.search("_(\d+).pt", f).group(1)))
207
-
208
- if len(saved) >= 1 and not corrupted(saved[-1]):
209
- return saved[-1]
210
- elif len(saved) >= 2:
211
- return saved[-2]
212
- else:
213
- return None
214
-
215
- def maybe_load(self, model, optimizer, scaler, train_state, args,
216
- ema_model=None):
217
-
218
- assert args.checkpoint_path is None or args.resume is False, (
219
- "Specify a single checkpoint source")
220
-
221
- fpath = None
222
- if args.checkpoint_path is not None:
223
- fpath = args.checkpoint_path
224
- self.tracked = OrderedDict() # Do not track/delete prev ckpts
225
- elif args.resume:
226
- fpath = self.last_checkpoint(args.output)
227
-
228
- if fpath is None:
229
- return
230
-
231
- print_once(f"Loading model and optimizer state from {fpath}")
232
- ckpt = torch.load(fpath, map_location="cpu")
233
- train_state["epoch"] = ckpt["epoch"] + 1
234
- train_state["total_iter"] = ckpt["iteration"]
235
-
236
- no_pref = lambda sd: {re.sub("^module.", "", k): v for k, v in sd.items()}
237
- unwrap = lambda m: getattr(m, "module", m)
238
-
239
- unwrap(model).load_state_dict(no_pref(ckpt["state_dict"]))
240
-
241
- if ema_model is not None:
242
- unwrap(ema_model).load_state_dict(no_pref(ckpt["ema_state_dict"]))
243
-
244
- optimizer.load_state_dict(ckpt["optimizer"])
245
-
246
- if "scaler" in ckpt:
247
- scaler.load_state_dict(ckpt["scaler"])
248
- else:
249
- warnings.warn("AMP scaler state missing from the checkpoint.")
250
-
251
- def maybe_save(self, args, model, ema_model, optimizer, scaler, epoch,
252
- total_iter, config):
253
-
254
- intermediate = (args.epochs_per_checkpoint > 0
255
- and epoch % args.epochs_per_checkpoint == 0)
256
- final = epoch == args.epochs
257
-
258
- if not intermediate and not final and epoch not in self.keep_milestones:
259
- return
260
-
261
- rank = 0
262
- if dist.is_initialized():
263
- dist.barrier()
264
- rank = dist.get_rank()
265
-
266
- if rank != 0:
267
- return
268
-
269
- unwrap = lambda m: getattr(m, "module", m)
270
- ckpt = {"epoch": epoch,
271
- "iteration": total_iter,
272
- "config": config,
273
- "train_setup": args.__dict__,
274
- "state_dict": unwrap(model).state_dict(),
275
- "optimizer": optimizer.state_dict(),
276
- "scaler": scaler.state_dict()}
277
- if ema_model is not None:
278
- ckpt["ema_state_dict"] = unwrap(ema_model).state_dict()
279
-
280
- fpath = Path(args.output, f"FastPitch_checkpoint_{epoch}.pt")
281
- print(f"Saving model and optimizer state at epoch {epoch} to {fpath}")
282
- torch.save(ckpt, fpath)
283
-
284
- # Remove old checkpoints; keep milestones and the last two
285
- self.tracked[epoch] = fpath
286
- for epoch in set(list(self.tracked)[:-2]) - set(self.keep_milestones):
287
- try:
288
- os.remove(self.tracked[epoch])
289
- except:
290
- pass
291
- del self.tracked[epoch]
fastpitch/data_function (copy).py.txt DELETED
@@ -1,425 +0,0 @@
1
- # *****************************************************************************
2
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
- #
4
- # Redistribution and use in source and binary forms, with or without
5
- # modification, are permitted provided that the following conditions are met:
6
- # * Redistributions of source code must retain the above copyright
7
- # notice, this list of conditions and the following disclaimer.
8
- # * Redistributions in binary form must reproduce the above copyright
9
- # notice, this list of conditions and the following disclaimer in the
10
- # documentation and/or other materials provided with the distribution.
11
- # * Neither the name of the NVIDIA CORPORATION nor the
12
- # names of its contributors may be used to endorse or promote products
13
- # derived from this software without specific prior written permission.
14
- #
15
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16
- # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17
- # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18
- # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19
- # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20
- # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21
- # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22
- # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23
- # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24
- # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
- #
26
- # *****************************************************************************
27
-
28
- import functools
29
- import json
30
- import re
31
- from pathlib import Path
32
-
33
- import librosa
34
- import numpy as np
35
- import torch
36
- import torch.nn.functional as F
37
- from scipy import ndimage
38
- from scipy.stats import betabinom
39
-
40
- import common.layers as layers
41
- from common.text.text_processing import TextProcessing
42
- from common.utils import load_wav_to_torch, load_filepaths_and_text, to_gpu
43
-
44
-
45
- class BetaBinomialInterpolator:
46
- """Interpolates alignment prior matrices to save computation.
47
-
48
- Calculating beta-binomial priors is costly. Instead cache popular sizes
49
- and use img interpolation to get priors faster.
50
- """
51
- def __init__(self, round_mel_len_to=100, round_text_len_to=20):
52
- self.round_mel_len_to = round_mel_len_to
53
- self.round_text_len_to = round_text_len_to
54
- self.bank = functools.lru_cache(beta_binomial_prior_distribution)
55
-
56
- def round(self, val, to):
57
- return max(1, int(np.round((val + 1) / to))) * to
58
-
59
- def __call__(self, w, h):
60
- bw = self.round(w, to=self.round_mel_len_to)
61
- bh = self.round(h, to=self.round_text_len_to)
62
- ret = ndimage.zoom(self.bank(bw, bh).T, zoom=(w / bw, h / bh), order=1)
63
- assert ret.shape[0] == w, ret.shape
64
- assert ret.shape[1] == h, ret.shape
65
- return ret
66
-
67
-
68
- def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling=1.0):
69
- P = phoneme_count
70
- M = mel_count
71
- x = np.arange(0, P)
72
- mel_text_probs = []
73
- for i in range(1, M+1):
74
- a, b = scaling * i, scaling * (M + 1 - i)
75
- rv = betabinom(P, a, b)
76
- mel_i_prob = rv.pmf(x)
77
- mel_text_probs.append(mel_i_prob)
78
- return torch.tensor(np.array(mel_text_probs))
79
-
80
-
81
- def estimate_pitch(wav, mel_len, method='pyin', normalize_mean=None,
82
- normalize_std=None, n_formants=1):
83
-
84
- if type(normalize_mean) is float or type(normalize_mean) is list:
85
- normalize_mean = torch.tensor(normalize_mean)
86
-
87
- if type(normalize_std) is float or type(normalize_std) is list:
88
- normalize_std = torch.tensor(normalize_std)
89
-
90
- if method == 'pyin':
91
-
92
- snd, sr = librosa.load(wav)
93
- pitch_mel, voiced_flag, voiced_probs = librosa.pyin(
94
- snd, fmin=librosa.note_to_hz('C2'),
95
- # fmax=librosa.note_to_hz('C7'), frame_length=1024)
96
- fmax=400, frame_length=1024)
97
- assert np.abs(mel_len - pitch_mel.shape[0]) <= 1.0
98
-
99
- pitch_mel = np.where(np.isnan(pitch_mel), 0.0, pitch_mel)
100
- pitch_mel = torch.from_numpy(pitch_mel).unsqueeze(0)
101
- pitch_mel = F.pad(pitch_mel, (0, mel_len - pitch_mel.size(1)))
102
-
103
- if n_formants > 1:
104
- raise NotImplementedError
105
-
106
- else:
107
- raise ValueError
108
-
109
- pitch_mel = pitch_mel.float()
110
-
111
- if normalize_mean is not None:
112
- assert normalize_std is not None
113
- pitch_mel = normalize_pitch(pitch_mel, normalize_mean, normalize_std)
114
-
115
- return pitch_mel
116
-
117
-
118
- def normalize_pitch(pitch, mean, std):
119
- zeros = (pitch == 0.0)
120
- pitch -= mean[:, None]
121
- pitch /= std[:, None]
122
- pitch[zeros] = 0.0
123
- return pitch
124
-
125
-
126
- class TTSDataset(torch.utils.data.Dataset):
127
- """
128
- 1) loads audio,text pairs
129
- 2) normalizes text and converts them to sequences of one-hot vectors
130
- 3) computes mel-spectrograms from audio files.
131
- """
132
- def __init__(self,
133
- dataset_path,
134
- audiopaths_and_text,
135
- text_cleaners,
136
- n_mel_channels,
137
- symbol_set='english_basic',
138
- p_arpabet=1.0,
139
- n_speakers=1,
140
- load_mel_from_disk=True,
141
- load_pitch_from_disk=True,
142
- pitch_mean=214.72203, # LJSpeech defaults
143
- pitch_std=65.72038,
144
- max_wav_value=None,
145
- sampling_rate=None,
146
- filter_length=None,
147
- hop_length=None,
148
- win_length=None,
149
- mel_fmin=None,
150
- mel_fmax=None,
151
- prepend_space_to_text=False,
152
- append_space_to_text=False,
153
- pitch_online_dir=None,
154
- betabinomial_online_dir=None,
155
- use_betabinomial_interpolator=True,
156
- pitch_online_method='pyin',
157
- **ignored):
158
-
159
- # Expect a list of filenames
160
- if type(audiopaths_and_text) is str:
161
- audiopaths_and_text = [audiopaths_and_text]
162
-
163
- self.dataset_path = dataset_path
164
- self.audiopaths_and_text = load_filepaths_and_text(
165
- dataset_path, audiopaths_and_text,
166
- has_speakers=(n_speakers > 1))
167
- self.load_mel_from_disk = load_mel_from_disk
168
- if not load_mel_from_disk:
169
- self.max_wav_value = max_wav_value
170
- self.sampling_rate = sampling_rate
171
- self.stft = layers.TacotronSTFT(
172
- filter_length, hop_length, win_length,
173
- n_mel_channels, sampling_rate, mel_fmin, mel_fmax)
174
- self.load_pitch_from_disk = load_pitch_from_disk
175
-
176
- self.prepend_space_to_text = prepend_space_to_text
177
- self.append_space_to_text = append_space_to_text
178
-
179
- assert p_arpabet == 0.0 or p_arpabet == 1.0, (
180
- 'Only 0.0 and 1.0 p_arpabet is currently supported. '
181
- 'Variable probability breaks caching of betabinomial matrices.')
182
-
183
- self.tp = TextProcessing(symbol_set, text_cleaners, p_arpabet=p_arpabet)
184
- self.n_speakers = n_speakers
185
- self.pitch_tmp_dir = pitch_online_dir
186
- self.f0_method = pitch_online_method
187
- self.betabinomial_tmp_dir = betabinomial_online_dir
188
- self.use_betabinomial_interpolator = use_betabinomial_interpolator
189
-
190
- if use_betabinomial_interpolator:
191
- self.betabinomial_interpolator = BetaBinomialInterpolator()
192
-
193
- expected_columns = (2 + int(load_pitch_from_disk) + (n_speakers > 1))
194
-
195
- assert not (load_pitch_from_disk and self.pitch_tmp_dir is not None)
196
-
197
- if len(self.audiopaths_and_text[0]) < expected_columns:
198
- raise ValueError(f'Expected {expected_columns} columns in audiopaths file. '
199
- 'The format is <mel_or_wav>|[<pitch>|]<text>[|<speaker_id>]')
200
-
201
- if len(self.audiopaths_and_text[0]) > expected_columns:
202
- print('WARNING: Audiopaths file has more columns than expected')
203
-
204
- to_tensor = lambda x: torch.Tensor([x]) if type(x) is float else x
205
- self.pitch_mean = to_tensor(pitch_mean)
206
- self.pitch_std = to_tensor(pitch_std)
207
-
208
- def __getitem__(self, index):
209
- # Separate filename and text
210
- if self.n_speakers > 1:
211
- audiopath, *extra, text, speaker = self.audiopaths_and_text[index]
212
- speaker = int(speaker)
213
- else:
214
- audiopath, *extra, text = self.audiopaths_and_text[index]
215
- speaker = None
216
-
217
- mel = self.get_mel(audiopath)
218
- text = self.get_text(text)
219
- # print(text)
220
- pitch = self.get_pitch(index, mel.size(-1))
221
- energy = torch.norm(mel.float(), dim=0, p=2)
222
- attn_prior = self.get_prior(index, mel.shape[1], text.shape[0])
223
-
224
- assert pitch.size(-1) == mel.size(-1)
225
-
226
- # No higher formants?
227
- if len(pitch.size()) == 1:
228
- pitch = pitch[None, :]
229
-
230
-
231
- return (text, mel, len(text), pitch, energy, speaker, attn_prior,
232
- audiopath)
233
-
234
- def __len__(self):
235
- return len(self.audiopaths_and_text)
236
-
237
- def get_mel(self, filename):
238
- if not self.load_mel_from_disk:
239
- audio, sampling_rate = load_wav_to_torch(filename)
240
- if sampling_rate != self.stft.sampling_rate:
241
- raise ValueError("{} SR doesn't match target {} SR".format(
242
- sampling_rate, self.stft.sampling_rate))
243
- audio_norm = audio / self.max_wav_value
244
- audio_norm = audio_norm.unsqueeze(0)
245
- audio_norm = torch.autograd.Variable(audio_norm,
246
- requires_grad=False)
247
- melspec = self.stft.mel_spectrogram(audio_norm)
248
- melspec = torch.squeeze(melspec, 0)
249
- else:
250
- melspec = torch.load(filename)
251
- assert melspec.size(0) == self.stft.n_mel_channels, (
252
- 'Mel dimension mismatch: given {}, expected {}'.format(
253
- melspec.size(0), self.stft.n_mel_channels))
254
-
255
- ################ Plotting mels ########################################
256
- import matplotlib.pyplot as plt
257
- # plt.imshow(melspec.detach().cpu().T,aspect="auto")
258
- fig, ax1 = plt.subplots(ncols=1)
259
- pos = ax1.imshow(melspec.cpu().numpy().T,aspect="auto")
260
- fig.colorbar(pos, ax=ax1)
261
- plt.show()
262
- #######################################################################
263
-
264
- return melspec
265
-
266
- def get_text(self, text):
267
- text = self.tp.encode_text(text)
268
- space = [self.tp.encode_text("A A")[1]]
269
-
270
- if self.prepend_space_to_text:
271
- text = space + text
272
-
273
- if self.append_space_to_text:
274
- text = text + space
275
-
276
- return torch.LongTensor(text)
277
-
278
- def get_prior(self, index, mel_len, text_len):
279
-
280
- if self.use_betabinomial_interpolator:
281
- return torch.from_numpy(self.betabinomial_interpolator(mel_len,
282
- text_len))
283
-
284
- if self.betabinomial_tmp_dir is not None:
285
- audiopath, *_ = self.audiopaths_and_text[index]
286
- fname = Path(audiopath).relative_to(self.dataset_path)
287
- fname = fname.with_suffix('.pt')
288
- cached_fpath = Path(self.betabinomial_tmp_dir, fname)
289
-
290
- if cached_fpath.is_file():
291
- return torch.load(cached_fpath)
292
-
293
- attn_prior = beta_binomial_prior_distribution(text_len, mel_len)
294
-
295
- if self.betabinomial_tmp_dir is not None:
296
- cached_fpath.parent.mkdir(parents=True, exist_ok=True)
297
- torch.save(attn_prior, cached_fpath)
298
-
299
- return attn_prior
300
-
301
- def get_pitch(self, index, mel_len=None):
302
- audiopath, *fields = self.audiopaths_and_text[index]
303
-
304
- if self.n_speakers > 1:
305
- spk = int(fields[-1])
306
- else:
307
- spk = 0
308
-
309
- if self.load_pitch_from_disk:
310
- pitchpath = fields[0]
311
- pitch = torch.load(pitchpath)
312
- if self.pitch_mean is not None:
313
- assert self.pitch_std is not None
314
- pitch = normalize_pitch(pitch, self.pitch_mean, self.pitch_std)
315
- return pitch
316
-
317
- if self.pitch_tmp_dir is not None:
318
- fname = Path(audiopath).relative_to(self.dataset_path)
319
- fname_method = fname.with_suffix('.pt')
320
- cached_fpath = Path(self.pitch_tmp_dir, fname_method)
321
- if cached_fpath.is_file():
322
- return torch.load(cached_fpath)
323
-
324
- # No luck so far - calculate
325
- wav = audiopath
326
- if not wav.endswith('.wav'):
327
- wav = re.sub('/mels/', '/wavs/', wav)
328
- wav = re.sub('.pt$', '.wav', wav)
329
-
330
- pitch_mel = estimate_pitch(wav, mel_len, self.f0_method,
331
- self.pitch_mean, self.pitch_std)
332
-
333
- if self.pitch_tmp_dir is not None and not cached_fpath.is_file():
334
- cached_fpath.parent.mkdir(parents=True, exist_ok=True)
335
- torch.save(pitch_mel, cached_fpath)
336
-
337
- return pitch_mel
338
-
339
-
340
- class TTSCollate:
341
- """Zero-pads model inputs and targets based on number of frames per step"""
342
-
343
- def __call__(self, batch):
344
- """Collate training batch from normalized text and mel-spec"""
345
- # Right zero-pad all one-hot text sequences to max input length
346
- input_lengths, ids_sorted_decreasing = torch.sort(
347
- torch.LongTensor([len(x[0]) for x in batch]),
348
- dim=0, descending=True)
349
- max_input_len = input_lengths[0]
350
-
351
- text_padded = torch.LongTensor(len(batch), max_input_len)
352
- text_padded.zero_()
353
- for i in range(len(ids_sorted_decreasing)):
354
- text = batch[ids_sorted_decreasing[i]][0]
355
- text_padded[i, :text.size(0)] = text
356
-
357
- # Right zero-pad mel-spec
358
- num_mels = batch[0][1].size(0)
359
- max_target_len = max([x[1].size(1) for x in batch])
360
-
361
- # Include mel padded and gate padded
362
- mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
363
- mel_padded.zero_()
364
- output_lengths = torch.LongTensor(len(batch))
365
- for i in range(len(ids_sorted_decreasing)):
366
- mel = batch[ids_sorted_decreasing[i]][1]
367
- mel_padded[i, :, :mel.size(1)] = mel
368
- output_lengths[i] = mel.size(1)
369
-
370
- n_formants = batch[0][3].shape[0]
371
- pitch_padded = torch.zeros(mel_padded.size(0), n_formants,
372
- mel_padded.size(2), dtype=batch[0][3].dtype)
373
- energy_padded = torch.zeros_like(pitch_padded[:, 0, :])
374
-
375
- for i in range(len(ids_sorted_decreasing)):
376
- pitch = batch[ids_sorted_decreasing[i]][3]
377
- energy = batch[ids_sorted_decreasing[i]][4]
378
- pitch_padded[i, :, :pitch.shape[1]] = pitch
379
- energy_padded[i, :energy.shape[0]] = energy
380
-
381
- if batch[0][5] is not None:
382
- speaker = torch.zeros_like(input_lengths)
383
- for i in range(len(ids_sorted_decreasing)):
384
- speaker[i] = batch[ids_sorted_decreasing[i]][5]
385
- else:
386
- speaker = None
387
-
388
- attn_prior_padded = torch.zeros(len(batch), max_target_len,
389
- max_input_len)
390
- attn_prior_padded.zero_()
391
- for i in range(len(ids_sorted_decreasing)):
392
- prior = batch[ids_sorted_decreasing[i]][6]
393
- attn_prior_padded[i, :prior.size(0), :prior.size(1)] = prior
394
-
395
- # Count number of items - characters in text
396
- len_x = [x[2] for x in batch]
397
- len_x = torch.Tensor(len_x)
398
-
399
- audiopaths = [batch[i][7] for i in ids_sorted_decreasing]
400
-
401
- return (text_padded, input_lengths, mel_padded, output_lengths, len_x,
402
- pitch_padded, energy_padded, speaker, attn_prior_padded,
403
- audiopaths)
404
-
405
-
406
- def batch_to_gpu(batch):
407
- (text_padded, input_lengths, mel_padded, output_lengths, len_x,
408
- pitch_padded, energy_padded, speaker, attn_prior, audiopaths) = batch
409
-
410
- text_padded = to_gpu(text_padded).long()
411
- input_lengths = to_gpu(input_lengths).long()
412
- mel_padded = to_gpu(mel_padded).float()
413
- output_lengths = to_gpu(output_lengths).long()
414
- pitch_padded = to_gpu(pitch_padded).float()
415
- energy_padded = to_gpu(energy_padded).float()
416
- attn_prior = to_gpu(attn_prior).float()
417
- if speaker is not None:
418
- speaker = to_gpu(speaker).long()
419
-
420
- # Alignments act as both inputs and targets - pass shallow copies
421
- x = [text_padded, input_lengths, mel_padded, output_lengths,
422
- pitch_padded, energy_padded, speaker, attn_prior, audiopaths]
423
- y = [mel_padded, input_lengths, output_lengths]
424
- len_x = torch.sum(output_lengths)
425
- return (x, y, len_x)
fastpitch/data_function_model_py.zip DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:7977160d12775529ba3426181093c8bca7927e52024f6d4faa91e1a0e53ef008
- size 9564
fastpitch/utils_trainplot_transformers.zip DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:5d3b72d713798a0552c0939fded67cb89655e12cc406a6fe793fcc7c9f63456a
- size 16365
fastpitch/utils_trainplot_transformers/train_1_with_plot.py DELETED
@@ -1,591 +0,0 @@
1
- # *****************************************************************************
2
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
- #
4
- # Redistribution and use in source and binary forms, with or without
5
- # modification, are permitted provided that the following conditions are met:
6
- # * Redistributions of source code must retain the above copyright
7
- # notice, this list of conditions and the following disclaimer.
8
- # * Redistributions in binary form must reproduce the above copyright
9
- # notice, this list of conditions and the following disclaimer in the
10
- # documentation and/or other materials provided with the distribution.
11
- # * Neither the name of the NVIDIA CORPORATION nor the
12
- # names of its contributors may be used to endorse or promote products
13
- # derived from this software without specific prior written permission.
14
- #
15
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16
- # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17
- # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18
- # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19
- # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20
- # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21
- # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22
- # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23
- # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24
- # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
- #
26
- # *****************************************************************************
27
-
28
- import argparse
29
- import copy
30
- import os
31
- import time
32
- from collections import defaultdict, OrderedDict
33
- from itertools import cycle
34
-
35
- import numpy as np
36
- import torch
37
- import torch.distributed as dist
38
- import amp_C
39
- from apex.optimizers import FusedAdam, FusedLAMB
40
- from torch.nn.parallel import DistributedDataParallel
41
- from torch.utils.data import DataLoader
42
- from torch.utils.data.distributed import DistributedSampler
43
-
44
- import common.tb_dllogger as logger
45
- import models
46
- from common.tb_dllogger import log
47
- from common.repeated_dataloader import (RepeatedDataLoader,
48
- RepeatedDistributedSampler)
49
- from common.text import cmudict
50
- from common.utils import BenchmarkStats, Checkpointer, prepare_tmp
51
- from fastpitch.attn_loss_function import AttentionBinarizationLoss
52
- from fastpitch.data_function import batch_to_gpu, TTSCollate, TTSDataset
53
- from fastpitch.loss_function import FastPitchLoss
54
-
55
- import matplotlib.pyplot as plt
56
-
57
- def parse_args(parser):
58
- parser.add_argument('-o', '--output', type=str, required=True,
59
- help='Directory to save checkpoints')
60
- parser.add_argument('-d', '--dataset-path', type=str, default='./',
61
- help='Path to dataset')
62
- parser.add_argument('--log-file', type=str, default=None,
63
- help='Path to a DLLogger log file')
64
-
65
- train = parser.add_argument_group('training setup')
66
- train.add_argument('--epochs', type=int, required=True,
67
- help='Number of total epochs to run')
68
- train.add_argument('--epochs-per-checkpoint', type=int, default=50,
69
- help='Number of epochs per checkpoint')
70
- train.add_argument('--checkpoint-path', type=str, default=None,
71
- help='Checkpoint path to resume training')
72
- train.add_argument('--keep-milestones', default=list(range(100, 1000, 100)),
73
- type=int, nargs='+',
74
- help='Milestone checkpoints to keep from removing')
75
- train.add_argument('--resume', action='store_true',
76
- help='Resume training from the last checkpoint')
77
- train.add_argument('--seed', type=int, default=1234,
78
- help='Seed for PyTorch random number generators')
79
- train.add_argument('--amp', action='store_true',
80
- help='Enable AMP')
81
- train.add_argument('--cuda', action='store_true',
82
- help='Run on GPU using CUDA')
83
- train.add_argument('--cudnn-benchmark', action='store_true',
84
- help='Enable cudnn benchmark mode')
85
- train.add_argument('--ema-decay', type=float, default=0,
86
- help='Discounting factor for training weights EMA')
87
- train.add_argument('--grad-accumulation', type=int, default=1,
88
- help='Training steps to accumulate gradients for')
89
- train.add_argument('--kl-loss-start-epoch', type=int, default=250,
90
- help='Start adding the hard attention loss term')
91
- train.add_argument('--kl-loss-warmup-epochs', type=int, default=100,
92
- help='Gradually increase the hard attention loss term')
93
- train.add_argument('--kl-loss-weight', type=float, default=1.0,
94
- help='Gradually increase the hard attention loss term')
95
- train.add_argument('--benchmark-epochs-num', type=int, default=20,
96
- help='Number of epochs for calculating final stats')
97
- train.add_argument('--validation-freq', type=int, default=1,
98
- help='Validate every N epochs to use less compute')
99
-
100
- opt = parser.add_argument_group('optimization setup')
101
- opt.add_argument('--optimizer', type=str, default='lamb',
102
- help='Optimization algorithm')
103
- opt.add_argument('-lr', '--learning-rate', type=float, required=True,
104
- help='Learing rate')
105
- opt.add_argument('--weight-decay', default=1e-6, type=float,
106
- help='Weight decay')
107
- opt.add_argument('--grad-clip-thresh', default=1000.0, type=float,
108
- help='Clip threshold for gradients')
109
- opt.add_argument('-bs', '--batch-size', type=int, required=True,
110
- help='Batch size per GPU')
111
- opt.add_argument('--warmup-steps', type=int, default=1000,
112
- help='Number of steps for lr warmup')
113
- opt.add_argument('--dur-predictor-loss-scale', type=float,
114
- default=1.0, help='Rescale duration predictor loss')
115
- opt.add_argument('--pitch-predictor-loss-scale', type=float,
116
- default=1.0, help='Rescale pitch predictor loss')
117
- opt.add_argument('--attn-loss-scale', type=float,
118
- default=1.0, help='Rescale alignment loss')
119
-
120
- data = parser.add_argument_group('dataset parameters')
121
- data.add_argument('--training-files', type=str, nargs='*', required=True,
122
- help='Paths to training filelists.')
123
- data.add_argument('--validation-files', type=str, nargs='*',
124
- required=True, help='Paths to validation filelists')
125
- data.add_argument('--text-cleaners', nargs='*',
126
- default=['english_cleaners'], type=str,
127
- help='Type of text cleaners for input text')
128
- data.add_argument('--symbol-set', type=str, default='english_basic',
129
- help='Define symbol set for input text')
130
- data.add_argument('--p-arpabet', type=float, default=0.0,
131
- help='Probability of using arpabets instead of graphemes '
132
- 'for each word; set 0 for pure grapheme training')
133
- data.add_argument('--heteronyms-path', type=str, default='cmudict/heteronyms',
134
- help='Path to the list of heteronyms')
135
- data.add_argument('--cmudict-path', type=str, default='cmudict/cmudict-0.7b',
136
- help='Path to the pronouncing dictionary')
137
- data.add_argument('--prepend-space-to-text', action='store_true',
138
- help='Capture leading silence with a space token')
139
- data.add_argument('--append-space-to-text', action='store_true',
140
- help='Capture trailing silence with a space token')
141
- data.add_argument('--num-workers', type=int, default=2, # 6
142
- help='Subprocesses for train and val DataLoaders')
143
- data.add_argument('--trainloader-repeats', type=int, default=100,
144
- help='Repeats the dataset to prolong epochs')
145
-
146
- cond = parser.add_argument_group('data for conditioning')
147
- cond.add_argument('--n-speakers', type=int, default=1,
148
- help='Number of speakers in the dataset. '
149
- 'n_speakers > 1 enables speaker embeddings')
150
- cond.add_argument('--load-pitch-from-disk', action='store_true',
151
- help='Use pitch cached on disk with prepare_dataset.py')
152
- cond.add_argument('--pitch-online-method', default='pyin',
153
- choices=['pyin'],
154
- help='Calculate pitch on the fly during training')
155
- cond.add_argument('--pitch-online-dir', type=str, default=None,
156
- help='A directory for storing pitch calculated on-line')
157
- cond.add_argument('--pitch-mean', type=float, default=125.626816, #default=214.72203,
158
- help='Pitch mean used for normalization')
159
- cond.add_argument('--pitch-std', type=float, default=37.52, #default=65.72038,
160
- help='Pitch standard deviation used for normalization')
161
- cond.add_argument('--load-mel-from-disk', action='store_true',
162
- help='Use mel-spectrograms cache on the disk') # XXX
163
-
164
- audio = parser.add_argument_group('audio parameters')
165
- audio.add_argument('--max-wav-value', default=32768.0, type=float,
166
- help='Maximum audiowave value')
167
- audio.add_argument('--sampling-rate', default=22050, type=int,
168
- help='Sampling rate')
169
- audio.add_argument('--filter-length', default=1024, type=int,
170
- help='Filter length')
171
- audio.add_argument('--hop-length', default=256, type=int,
172
- help='Hop (stride) length')
173
- audio.add_argument('--win-length', default=1024, type=int,
174
- help='Window length')
175
- audio.add_argument('--mel-fmin', default=0.0, type=float,
176
- help='Minimum mel frequency')
177
- audio.add_argument('--mel-fmax', default=8000.0, type=float,
178
- help='Maximum mel frequency')
179
-
180
- dist = parser.add_argument_group('distributed setup')
181
- dist.add_argument('--local_rank', type=int, default=os.getenv('LOCAL_RANK', 0),
182
- help='Rank of the process for multiproc; do not set manually')
183
- dist.add_argument('--world_size', type=int, default=os.getenv('WORLD_SIZE', 1),
184
- help='Number of processes for multiproc; do not set manually')
185
- return parser
186
-
187
-
188
- def reduce_tensor(tensor, num_gpus):
189
- rt = tensor.clone()
190
- dist.all_reduce(rt, op=dist.ReduceOp.SUM)
191
- return rt.true_divide(num_gpus)
192
-
193
-
194
- def init_distributed(args, world_size, rank):
195
- assert torch.cuda.is_available(), "Distributed mode requires CUDA."
196
- print("Initializing distributed training")
197
-
198
- # Set cuda device so everything is done on the right GPU.
199
- torch.cuda.set_device(rank % torch.cuda.device_count())
200
-
201
- # Initialize distributed communication
202
- dist.init_process_group(backend=('nccl' if args.cuda else 'gloo'),
203
- init_method='env://')
204
- print("Done initializing distributed training")
205
-
206
-
207
- def validate(model, epoch, total_iter, criterion, val_loader, distributed_run,
208
- batch_to_gpu, local_rank, ema=False):
209
- was_training = model.training
210
- model.eval()
211
-
212
- tik = time.perf_counter()
213
- with torch.no_grad():
214
- val_meta = defaultdict(float)
215
- val_num_frames = 0
216
- for i, batch in enumerate(val_loader):
217
- x, y, num_frames = batch_to_gpu(batch)
218
- y_pred = model(x)
219
- loss, meta = criterion(y_pred, y, is_training=False, meta_agg='sum')
220
-
221
- if distributed_run:
222
- for k, v in meta.items():
223
- val_meta[k] += reduce_tensor(v, 1)
224
- val_num_frames += reduce_tensor(num_frames.data, 1).item()
225
- else:
226
- for k, v in meta.items():
227
- val_meta[k] += v
228
- val_num_frames += num_frames.item()
229
-
230
- # NOTE: ugly patch to visualize the first utterance of the validation corpus.
231
- # The goal is to determine if the training is progressing properly
232
- if (i == 0) and (local_rank == 0) and (not ema):
233
- # Plot some debug information
234
- fig, axs = plt.subplots(2, 2, figsize=(21,14))
235
-
236
- # - Mel-spectrogram
237
- pred_mel = y_pred[0][0, :, :].cpu().detach().numpy().astype(np.float32).T
238
- orig_mel = y[0][0, :, :].cpu().detach().numpy().astype(np.float32)
239
- axs[0,0].imshow(orig_mel, aspect='auto', origin='lower', interpolation='nearest')
240
- axs[1,0].imshow(pred_mel, aspect='auto', origin='lower', interpolation='nearest')
241
-
242
- # Prosody
243
- f0_pred = y_pred[4][0, :].cpu().detach().numpy().astype(np.float32)
244
- f0_ori = y_pred[5][0, :].cpu().detach().numpy().astype(np.float32)
245
- axs[1,1].plot(f0_ori)
246
- axs[1,1].plot(f0_pred)
247
-
248
- # # Duration
249
- # att_pred = y_pred[2][0, :].cpu().detach().numpy().astype(np.float32)
250
- # att_ori = x[7][0,:].cpu().detach().numpy().astype(np.float32)
251
- # axs[0,1].imshow(att_ori, aspect='auto', origin='lower', interpolation='nearest')
252
-
253
- if not os.path.exists("debug_epoch/"):
254
- os.makedirs("debug_epoch/")
255
-
256
- fig.savefig(f'debug_epoch/{epoch:06d}.png', bbox_inches='tight')
257
-
258
- val_meta = {k: v / len(val_loader.dataset) for k, v in val_meta.items()}
259
-
260
- val_meta['took'] = time.perf_counter() - tik
261
-
262
- log((epoch,) if epoch is not None else (), tb_total_steps=total_iter,
263
- subset='val_ema' if ema else 'val',
264
- data=OrderedDict([
265
- ('loss', val_meta['loss'].item()),
266
- ('mel_loss', val_meta['mel_loss'].item()),
267
- ('frames/s', val_num_frames / val_meta['took']),
268
- ('took', val_meta['took'])]),
269
- )
270
-
271
- if was_training:
272
- model.train()
273
- return val_meta
274
-
275
-
276
- def adjust_learning_rate(total_iter, opt, learning_rate, warmup_iters=None):
277
- if warmup_iters == 0:
278
- scale = 1.0
279
- elif total_iter > warmup_iters:
280
- scale = 1. / (total_iter ** 0.5)
281
- else:
282
- scale = total_iter / (warmup_iters ** 1.5)
283
-
284
- for param_group in opt.param_groups:
285
- param_group['lr'] = learning_rate * scale
286
-
287
-
288
- def apply_ema_decay(model, ema_model, decay):
289
- if not decay:
290
- return
291
- st = model.state_dict()
292
- add_module = hasattr(model, 'module') and not hasattr(ema_model, 'module')
293
- for k, v in ema_model.state_dict().items():
294
- if add_module and not k.startswith('module.'):
295
- k = 'module.' + k
296
- v.copy_(decay * v + (1 - decay) * st[k])
297
-
298
-
299
- def init_multi_tensor_ema(model, ema_model):
300
- model_weights = list(model.state_dict().values())
301
- ema_model_weights = list(ema_model.state_dict().values())
302
- ema_overflow_buf = torch.cuda.IntTensor([0])
303
- return model_weights, ema_model_weights, ema_overflow_buf
304
-
305
-
306
- def apply_multi_tensor_ema(decay, model_weights, ema_weights, overflow_buf):
307
- amp_C.multi_tensor_axpby(
308
- 65536, overflow_buf, [ema_weights, model_weights, ema_weights],
309
- decay, 1-decay, -1)
310
-
311
-
312
- def main():
313
- parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
314
- allow_abbrev=False)
315
- parser = parse_args(parser)
316
- args, _ = parser.parse_known_args()
317
-
318
- if args.p_arpabet > 0.0:
319
- cmudict.initialize(args.cmudict_path, args.heteronyms_path)
320
-
321
- distributed_run = args.world_size > 1
322
-
323
- torch.manual_seed(args.seed + args.local_rank)
324
- np.random.seed(args.seed + args.local_rank)
325
-
326
- if args.local_rank == 0:
327
- if not os.path.exists(args.output):
328
- os.makedirs(args.output)
329
-
330
- log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
331
- tb_subsets = ['train', 'val']
332
- if args.ema_decay > 0.0:
333
- tb_subsets.append('val_ema')
334
-
335
- logger.init(log_fpath, args.output, enabled=(args.local_rank == 0),
336
- tb_subsets=tb_subsets)
337
- logger.parameters(vars(args), tb_subset='train')
338
-
339
- parser = models.parse_model_args('FastPitch', parser)
340
- args, unk_args = parser.parse_known_args()
341
- if len(unk_args) > 0:
342
- raise ValueError(f'Invalid options {unk_args}')
343
-
344
- torch.backends.cudnn.benchmark = args.cudnn_benchmark
345
-
346
- if distributed_run:
347
- init_distributed(args, args.world_size, args.local_rank)
348
- else:
349
- if args.trainloader_repeats > 1:
350
- print('WARNING: Disabled --trainloader-repeats, supported only for'
351
- ' multi-GPU data loading.')
352
- args.trainloader_repeats = 1
353
-
354
- device = torch.device('cuda' if args.cuda else 'cpu')
355
- model_config = models.get_model_config('FastPitch', args)
356
- model = models.get_model('FastPitch', model_config, device)
357
-
358
- attention_kl_loss = AttentionBinarizationLoss()
359
-
360
- # Store pitch mean/std as params to translate from Hz during inference
361
- model.pitch_mean[0] = args.pitch_mean
362
- model.pitch_std[0] = args.pitch_std
363
-
364
- kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
365
- weight_decay=args.weight_decay)
366
- if args.optimizer == 'adam':
367
- optimizer = FusedAdam(model.parameters(), **kw)
368
- # optimizer = torch.optim.Adam(model.parameters(), **kw)
369
- elif args.optimizer == 'lamb':
370
-
371
- optimizer = FusedLAMB(model.parameters(), **kw)
372
- # optimizer = torch.optim.Adam(model.parameters(), **kw)
373
- else:
374
- raise ValueError
375
-
376
- scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
377
-
378
- if args.ema_decay > 0:
379
- ema_model = copy.deepcopy(model)
380
- else:
381
- ema_model = None
382
-
383
- if distributed_run:
384
- model = DistributedDataParallel(
385
- model, device_ids=[args.local_rank], output_device=args.local_rank,
386
- find_unused_parameters=True)
387
-
388
- train_state = {'epoch': 1, 'total_iter': 1}
389
- checkpointer = Checkpointer(args.output, args.keep_milestones)
390
-
391
- checkpointer.maybe_load(model, optimizer, scaler, train_state, args,
392
- ema_model)
393
-
394
- start_epoch = train_state['epoch']
395
- total_iter = train_state['total_iter']
396
-
397
- criterion = FastPitchLoss(
398
- dur_predictor_loss_scale=args.dur_predictor_loss_scale,
399
- pitch_predictor_loss_scale=args.pitch_predictor_loss_scale,
400
- attn_loss_scale=args.attn_loss_scale)
401
-
402
- collate_fn = TTSCollate()
403
-
404
- if args.local_rank == 0:
405
- prepare_tmp(args.pitch_online_dir)
406
-
407
- trainset = TTSDataset(audiopaths_and_text=args.training_files, **vars(args))
408
- valset = TTSDataset(audiopaths_and_text=args.validation_files, **vars(args))
409
-
410
- if distributed_run:
411
- train_sampler = RepeatedDistributedSampler(args.trainloader_repeats,
412
- trainset, drop_last=True)
413
- val_sampler = DistributedSampler(valset)
414
- shuffle = False
415
- else:
416
- train_sampler, val_sampler, shuffle = None, None, False  # was True
417
-
418
- # 4 workers are optimal on DGX-1 (from epoch 2 onwards)
419
- kw = {'num_workers': args.num_workers, 'batch_size': args.batch_size,
420
- 'collate_fn': collate_fn}
421
- train_loader = RepeatedDataLoader(args.trainloader_repeats, trainset,
422
- shuffle=shuffle, drop_last=True,
423
- sampler=train_sampler, pin_memory=True,
424
- persistent_workers=True, **kw)
425
- val_loader = DataLoader(valset, shuffle=False, sampler=val_sampler,
426
- pin_memory=False, **kw)
427
- if args.ema_decay:
428
- mt_ema_params = init_multi_tensor_ema(model, ema_model)
429
-
430
- model.train()
431
- bmark_stats = BenchmarkStats()
432
-
433
- torch.cuda.synchronize()
434
- for epoch in range(start_epoch, args.epochs + 1):
435
- epoch_start_time = time.perf_counter()
436
-
437
- epoch_loss = 0.0
438
- epoch_mel_loss = 0.0
439
- epoch_num_frames = 0
440
- epoch_frames_per_sec = 0.0
441
-
442
- if distributed_run:
443
- train_loader.sampler.set_epoch(epoch)
444
-
445
- iter_loss = 0
446
- iter_num_frames = 0
447
- iter_meta = {}
448
- iter_start_time = time.perf_counter()
449
-
450
- epoch_iter = 1
451
- for batch, accum_step in zip(train_loader,
452
- cycle(range(1, args.grad_accumulation + 1))):
453
- if accum_step == 1:
454
- adjust_learning_rate(total_iter, optimizer, args.learning_rate,
455
- args.warmup_steps)
456
-
457
- model.zero_grad(set_to_none=True)
458
-
459
- x, y, num_frames = batch_to_gpu(batch)
460
-
461
- with torch.cuda.amp.autocast(enabled=args.amp):
462
- y_pred = model(x)
463
- loss, meta = criterion(y_pred, y)
464
-
465
- if (args.kl_loss_start_epoch is not None
466
- and epoch >= args.kl_loss_start_epoch):
467
-
468
- if args.kl_loss_start_epoch == epoch and epoch_iter == 1:
469
- print('Begin hard_attn loss')
470
-
471
- _, _, _, _, _, _, _, _, attn_soft, attn_hard, _, _ = y_pred
472
- binarization_loss = attention_kl_loss(attn_hard, attn_soft)
473
- kl_weight = min((epoch - args.kl_loss_start_epoch) / args.kl_loss_warmup_epochs, 1.0) * args.kl_loss_weight
474
- meta['kl_loss'] = binarization_loss.clone().detach() * kl_weight
475
- loss += kl_weight * binarization_loss
476
-
477
- else:
478
- meta['kl_loss'] = torch.zeros_like(loss)
479
- kl_weight = 0
480
- binarization_loss = 0
481
-
482
- loss /= args.grad_accumulation
483
-
484
- meta = {k: v / args.grad_accumulation
485
- for k, v in meta.items()}
486
-
487
- if args.amp:
488
- scaler.scale(loss).backward()
489
- else:
490
- loss.backward()
491
-
492
- if distributed_run:
493
- reduced_loss = reduce_tensor(loss.data, args.world_size).item()
494
- reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
495
- meta = {k: reduce_tensor(v, args.world_size) for k, v in meta.items()}
496
- else:
497
- reduced_loss = loss.item()
498
- reduced_num_frames = num_frames.item()
499
- if np.isnan(reduced_loss):
500
- raise Exception("loss is NaN")
501
-
502
- iter_loss += reduced_loss
503
- iter_num_frames += reduced_num_frames
504
- iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}
505
-
506
- if accum_step % args.grad_accumulation == 0:
507
-
508
- logger.log_grads_tb(total_iter, model)
509
- if args.amp:
510
- scaler.unscale_(optimizer)
511
- torch.nn.utils.clip_grad_norm_(
512
- model.parameters(), args.grad_clip_thresh)
513
- scaler.step(optimizer)
514
- scaler.update()
515
- else:
516
- torch.nn.utils.clip_grad_norm_(
517
- model.parameters(), args.grad_clip_thresh)
518
- optimizer.step()
519
-
520
- if args.ema_decay > 0.0:
521
- apply_multi_tensor_ema(args.ema_decay, *mt_ema_params)
522
-
523
- iter_mel_loss = iter_meta['mel_loss'].item()
524
- iter_kl_loss = iter_meta['kl_loss'].item()
525
- iter_time = time.perf_counter() - iter_start_time
526
- epoch_frames_per_sec += iter_num_frames / iter_time
527
- epoch_loss += iter_loss
528
- epoch_num_frames += iter_num_frames
529
- epoch_mel_loss += iter_mel_loss
530
-
531
- num_iters = len(train_loader) // args.grad_accumulation
532
- log((epoch, epoch_iter, num_iters), tb_total_steps=total_iter,
533
- subset='train', data=OrderedDict([
534
- ('loss', iter_loss),
535
- ('mel_loss', iter_mel_loss),
536
- ('kl_loss', iter_kl_loss),
537
- ('kl_weight', kl_weight),
538
- ('frames/s', iter_num_frames / iter_time),
539
- ('took', iter_time),
540
- ('lrate', optimizer.param_groups[0]['lr'])]),
541
- )
542
-
543
- iter_loss = 0
544
- iter_num_frames = 0
545
- iter_meta = {}
546
- iter_start_time = time.perf_counter()
547
-
548
- if epoch_iter == num_iters:
549
- break
550
- epoch_iter += 1
551
- total_iter += 1
552
-
553
- # Finished epoch
554
- epoch_loss /= epoch_iter
555
- epoch_mel_loss /= epoch_iter
556
- epoch_time = time.perf_counter() - epoch_start_time
557
-
558
- log((epoch,), tb_total_steps=None, subset='train_avg',
559
- data=OrderedDict([
560
- ('loss', epoch_loss),
561
- ('mel_loss', epoch_mel_loss),
562
- ('frames/s', epoch_num_frames / epoch_time),
563
- ('took', epoch_time)]),
564
- )
565
- bmark_stats.update(epoch_num_frames, epoch_loss, epoch_mel_loss,
566
- epoch_time)
567
-
568
- if epoch % args.validation_freq == 0:
569
- validate(model, epoch, total_iter, criterion, val_loader,
570
- distributed_run, batch_to_gpu, ema=False, local_rank=args.local_rank)
571
-
572
- if args.ema_decay > 0:
573
- validate(ema_model, epoch, total_iter, criterion, val_loader,
574
- distributed_run, batch_to_gpu, args.local_rank, ema=True)
575
-
576
- # save before making sched.step() for proper loading of LR
577
- checkpointer.maybe_save(args, model, ema_model, optimizer, scaler,
578
- epoch, total_iter, model_config)
579
- logger.flush()
580
-
581
- # Finished training
582
- if len(bmark_stats) > 0:
583
- log((), tb_total_steps=None, subset='train_avg',
584
- data=bmark_stats.get(args.benchmark_epochs_num))
585
-
586
- validate(model, None, total_iter, criterion, val_loader, distributed_run,
587
- batch_to_gpu)
588
-
589
-
590
- if __name__ == '__main__':
591
- main()
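
Note: adjust_learning_rate() above implements a linear warmup followed by inverse square-root decay. A minimal standalone sketch of the same rule, with illustrative numbers only (not values from any training run):

import math

def lr_scale(step, warmup_steps):
    # Mirrors adjust_learning_rate() above: ramp up during warmup,
    # then decay proportionally to 1/sqrt(step).
    if warmup_steps == 0:
        return 1.0
    if step > warmup_steps:
        return 1.0 / math.sqrt(step)
    return step / warmup_steps ** 1.5

base_lr = 0.1  # whatever is passed via --learning-rate
for step in (1, 500, 1000, 2000, 10000):
    print(step, base_lr * lr_scale(step, warmup_steps=1000))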
 
fastpitch/utils_trainplot_transformers/transformer.py DELETED
@@ -1,213 +0,0 @@
1
- # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import torch
16
- import torch.nn as nn
17
- import torch.nn.functional as F
18
-
19
- from common.utils import mask_from_lens
20
-
21
-
22
- class PositionalEmbedding(nn.Module):
23
- def __init__(self, demb):
24
- super(PositionalEmbedding, self).__init__()
25
- self.demb = demb
26
- inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
27
- self.register_buffer('inv_freq', inv_freq)
28
-
29
- def forward(self, pos_seq, bsz=None):
30
- sinusoid_inp = torch.matmul(torch.unsqueeze(pos_seq, -1),
31
- torch.unsqueeze(self.inv_freq, 0))
32
- pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
33
- if bsz is not None:
34
- return pos_emb[None, :, :].expand(bsz, -1, -1)
35
- else:
36
- return pos_emb[None, :, :]
37
-
38
-
39
- class PositionwiseConvFF(nn.Module):
40
- def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False):
41
- super(PositionwiseConvFF, self).__init__()
42
-
43
- self.d_model = d_model
44
- self.d_inner = d_inner
45
- self.dropout = dropout
46
-
47
- self.CoreNet = nn.Sequential(
48
- nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)),
49
- nn.ReLU(),
50
- # nn.Dropout(dropout), # worse convergence
51
- nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)),
52
- nn.Dropout(dropout),
53
- )
54
- self.layer_norm = nn.LayerNorm(d_model)
55
- self.pre_lnorm = pre_lnorm
56
-
57
- def forward(self, inp):
58
- return self._forward(inp)
59
-
60
- def _forward(self, inp):
61
- if self.pre_lnorm:
62
- # layer normalization + positionwise feed-forward
63
- core_out = inp.transpose(1, 2)
64
- core_out = self.CoreNet(self.layer_norm(core_out).to(inp.dtype))
65
- core_out = core_out.transpose(1, 2)
66
-
67
- # residual connection
68
- output = core_out + inp
69
- else:
70
- # positionwise feed-forward
71
- core_out = inp.transpose(1, 2)
72
- core_out = self.CoreNet(core_out)
73
- core_out = core_out.transpose(1, 2)
74
-
75
- # residual connection + layer normalization
76
- output = self.layer_norm(inp + core_out).to(inp.dtype)
77
-
78
- return output
79
-
80
-
81
- class MultiHeadAttn(nn.Module):
82
- def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1,
83
- pre_lnorm=False):
84
- super(MultiHeadAttn, self).__init__()
85
-
86
- self.n_head = n_head
87
- self.d_model = d_model
88
- self.d_head = d_head
89
- self.scale = 1 / (d_head ** 0.5)
90
- self.pre_lnorm = pre_lnorm
91
-
92
- self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head)
93
- self.drop = nn.Dropout(dropout)
94
- self.dropatt = nn.Dropout(dropatt)
95
- self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
96
- self.layer_norm = nn.LayerNorm(d_model)
97
-
98
- def forward(self, inp, attn_mask=None):
99
- return self._forward(inp, attn_mask)
100
-
101
- def _forward(self, inp, attn_mask=None):
102
- residual = inp
103
-
104
- if self.pre_lnorm:
105
- # layer normalization
106
- inp = self.layer_norm(inp)
107
-
108
- n_head, d_head = self.n_head, self.d_head
109
-
110
- head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2)
111
- head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head)
112
- head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head)
113
- head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head)
114
-
115
- q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
116
- k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
117
- v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
118
-
119
- attn_score = torch.bmm(q, k.transpose(1, 2))
120
- attn_score.mul_(self.scale)
121
-
122
- if attn_mask is not None:
123
- attn_mask = attn_mask.unsqueeze(1).to(attn_score.dtype)
124
- attn_mask = attn_mask.repeat(n_head, attn_mask.size(2), 1)
125
- attn_score.masked_fill_(attn_mask.to(torch.bool), -float('inf'))
126
-
127
- attn_prob = F.softmax(attn_score, dim=2)
128
- attn_prob = self.dropatt(attn_prob)
129
- attn_vec = torch.bmm(attn_prob, v)
130
-
131
- attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head)
132
- attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(
133
- inp.size(0), inp.size(1), n_head * d_head)
134
-
135
- # linear projection
136
- attn_out = self.o_net(attn_vec)
137
- attn_out = self.drop(attn_out)
138
-
139
- if self.pre_lnorm:
140
- # residual connection
141
- output = residual + attn_out
142
- else:
143
- # residual connection + layer normalization
144
- output = self.layer_norm(residual + attn_out)
145
-
146
- output = output.to(attn_out.dtype)
147
-
148
- return output
149
-
150
-
151
- class TransformerLayer(nn.Module):
152
- def __init__(self, n_head, d_model, d_head, d_inner, kernel_size, dropout,
153
- **kwargs):
154
- super(TransformerLayer, self).__init__()
155
-
156
- self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
157
- self.pos_ff = PositionwiseConvFF(d_model, d_inner, kernel_size, dropout,
158
- pre_lnorm=kwargs.get('pre_lnorm'))
159
-
160
- def forward(self, dec_inp, mask=None):
161
- output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2))
162
- output *= mask
163
- output = self.pos_ff(output)
164
- output *= mask
165
- return output
166
-
167
-
168
- class FFTransformer(nn.Module):
169
- def __init__(self, n_layer, n_head, d_model, d_head, d_inner, kernel_size,
170
- dropout, dropatt, dropemb=0.0, embed_input=True,
171
- n_embed=None, d_embed=None, padding_idx=0, pre_lnorm=False):
172
- super(FFTransformer, self).__init__()
173
- self.d_model = d_model
174
- self.n_head = n_head
175
- self.d_head = d_head
176
- self.padding_idx = padding_idx
177
-
178
- if embed_input:
179
- self.word_emb = nn.Embedding(n_embed, d_embed or d_model,
180
- padding_idx=self.padding_idx)
181
- else:
182
- self.word_emb = None
183
-
184
- self.pos_emb = PositionalEmbedding(self.d_model)
185
- self.drop = nn.Dropout(dropemb)
186
- self.layers = nn.ModuleList()
187
-
188
- for _ in range(n_layer):
189
- self.layers.append(
190
- TransformerLayer(
191
- n_head, d_model, d_head, d_inner, kernel_size, dropout,
192
- dropatt=dropatt, pre_lnorm=pre_lnorm)
193
- )
194
-
195
- def forward(self, dec_inp, seq_lens=None, conditioning=0):
196
- if self.word_emb is None:
197
- inp = dec_inp
198
- mask = mask_from_lens(seq_lens).unsqueeze(2)
199
- else:
200
- inp = self.word_emb(dec_inp)
201
- # [bsz x L x 1]
202
- mask = (dec_inp != self.padding_idx).unsqueeze(2)
203
-
204
- pos_seq = torch.arange(inp.size(1), device=inp.device).to(inp.dtype)
205
- pos_emb = self.pos_emb(pos_seq) * mask
206
-
207
- out = self.drop(inp + pos_emb + conditioning)
208
-
209
- for layer in self.layers:
210
- out = layer(out, mask=mask)
211
-
212
- # out = self.drop(out)
213
- return out, mask
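
Note: a quick way to sanity-check the FFTransformer stack above is to push a dummy batch of symbol IDs through it. The sketch below assumes the class is importable from this module; the import path and hyperparameter values are placeholders, not the values used for training:

import torch
# placeholder import -- adjust to where FFTransformer lives in your checkout
# from fastpitch.transformer import FFTransformer

encoder = FFTransformer(n_layer=2, n_head=2, d_model=64, d_head=32,
                        d_inner=256, kernel_size=3, dropout=0.1, dropatt=0.1,
                        embed_input=True, n_embed=148, padding_idx=0)
tokens = torch.randint(1, 148, (4, 23))   # (batch, max_len) symbol IDs, 0 = padding
out, mask = encoder(tokens)               # conditioning defaults to 0 (no speaker/language embedding)
print(out.shape)    # torch.Size([4, 23, 64])
print(mask.shape)   # torch.Size([4, 23, 1])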
 
fastpitch/utils_trainplot_transformers/transformer_jit.py DELETED
@@ -1,255 +0,0 @@
1
- # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from typing import List, Optional
16
-
17
- import torch
18
- import torch.nn as nn
19
- import torch.nn.functional as F
20
-
21
- from common.utils import mask_from_lens
22
-
23
-
24
- class PositionalEmbedding(nn.Module):
25
- def __init__(self, demb):
26
- super(PositionalEmbedding, self).__init__()
27
- self.demb = demb
28
- inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
29
- self.register_buffer('inv_freq', inv_freq)
30
-
31
- def forward(self, pos_seq, bsz: Optional[int] = None):
32
- sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
33
- pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
34
- if bsz is not None:
35
- return pos_emb[None, :, :].expand(bsz, -1, -1)
36
- else:
37
- return pos_emb[None, :, :]
38
-
39
-
40
- class PositionwiseFF(nn.Module):
41
- def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
42
- super(PositionwiseFF, self).__init__()
43
-
44
- self.d_model = d_model
45
- self.d_inner = d_inner
46
- self.dropout = dropout
47
-
48
- self.CoreNet = nn.Sequential(
49
- nn.Linear(d_model, d_inner), nn.ReLU(),
50
- nn.Dropout(dropout),
51
- nn.Linear(d_inner, d_model),
52
- nn.Dropout(dropout),
53
- )
54
-
55
- self.layer_norm = nn.LayerNorm(d_model)
56
- self.pre_lnorm = pre_lnorm
57
-
58
- def forward(self, inp):
59
- if self.pre_lnorm:
60
- # layer normalization + positionwise feed-forward
61
- core_out = self.CoreNet(self.layer_norm(inp))
62
-
63
- # residual connection
64
- output = core_out + inp
65
- else:
66
- # positionwise feed-forward
67
- core_out = self.CoreNet(inp)
68
-
69
- # residual connection + layer normalization
70
- output = self.layer_norm(inp + core_out)
71
-
72
- return output
73
-
74
-
75
- class PositionwiseConvFF(nn.Module):
76
- def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False):
77
- super(PositionwiseConvFF, self).__init__()
78
-
79
- self.d_model = d_model
80
- self.d_inner = d_inner
81
- self.dropout = dropout
82
-
83
- self.CoreNet = nn.Sequential(
84
- nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)),
85
- nn.ReLU(),
86
- # nn.Dropout(dropout), # worse convergence
87
- nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)),
88
- nn.Dropout(dropout),
89
- )
90
- self.layer_norm = nn.LayerNorm(d_model)
91
- self.pre_lnorm = pre_lnorm
92
-
93
- def forward(self, inp):
94
- if self.pre_lnorm:
95
- # layer normalization + positionwise feed-forward
96
- core_out = inp.transpose(1, 2)
97
- core_out = self.CoreNet(self.layer_norm(core_out))
98
- core_out = core_out.transpose(1, 2)
99
-
100
- # residual connection
101
- output = core_out + inp
102
- else:
103
- # positionwise feed-forward
104
- core_out = inp.transpose(1, 2)
105
- core_out = self.CoreNet(core_out)
106
- core_out = core_out.transpose(1, 2)
107
-
108
- # residual connection + layer normalization
109
- output = self.layer_norm(inp + core_out)
110
-
111
- return output
112
-
113
-
114
- class MultiHeadAttn(nn.Module):
115
- def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1,
116
- pre_lnorm=False):
117
- super(MultiHeadAttn, self).__init__()
118
-
119
- self.n_head = n_head
120
- self.d_model = d_model
121
- self.d_head = d_head
122
- self.scale = 1 / (d_head ** 0.5)
123
- self.dropout = dropout
124
- self.pre_lnorm = pre_lnorm
125
-
126
- self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head)
127
- self.drop = nn.Dropout(dropout)
128
- self.dropatt = nn.Dropout(dropatt)
129
- self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
130
- self.layer_norm = nn.LayerNorm(d_model)
131
-
132
-
133
- def forward(self, inp, attn_mask: Optional[torch.Tensor] = None):
134
- residual = inp
135
-
136
- if self.pre_lnorm:
137
- # layer normalization
138
- inp = self.layer_norm(inp)
139
-
140
- n_head, d_head = self.n_head, self.d_head
141
-
142
- head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=-1)
143
- head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head)
144
- head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head)
145
- head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head)
146
-
147
- q = head_q.permute(0, 2, 1, 3).reshape(-1, inp.size(1), d_head)
148
- k = head_k.permute(0, 2, 1, 3).reshape(-1, inp.size(1), d_head)
149
- v = head_v.permute(0, 2, 1, 3).reshape(-1, inp.size(1), d_head)
150
-
151
- attn_score = torch.bmm(q, k.transpose(1, 2))
152
- attn_score.mul_(self.scale)
153
-
154
- if attn_mask is not None:
155
- attn_mask = attn_mask.unsqueeze(1)
156
- attn_mask = attn_mask.repeat(n_head, attn_mask.size(2), 1)
157
- attn_score.masked_fill_(attn_mask, -float('inf'))
158
-
159
- attn_prob = F.softmax(attn_score, dim=2)
160
- attn_prob = self.dropatt(attn_prob)
161
- attn_vec = torch.bmm(attn_prob, v)
162
-
163
- attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head)
164
- attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(
165
- inp.size(0), inp.size(1), n_head * d_head)
166
-
167
- # linear projection
168
- attn_out = self.o_net(attn_vec)
169
- attn_out = self.drop(attn_out)
170
-
171
- if self.pre_lnorm:
172
- # residual connection
173
- output = residual + attn_out
174
- else:
175
- # residual connection + layer normalization
176
-
177
- # XXX Running TorchScript on 20.02 and 20.03 containers crashes here
178
- # XXX Works well with 20.01-py3 container.
179
- # XXX dirty fix is:
180
- # XXX output = self.layer_norm(residual + attn_out).half()
181
- output = self.layer_norm(residual + attn_out)
182
-
183
- return output
184
-
185
-
186
- class TransformerLayer(nn.Module):
187
- def __init__(self, n_head, d_model, d_head, d_inner, kernel_size, dropout,
188
- **kwargs):
189
- super(TransformerLayer, self).__init__()
190
-
191
- self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
192
- self.pos_ff = PositionwiseConvFF(d_model, d_inner, kernel_size, dropout,
193
- pre_lnorm=kwargs.get('pre_lnorm'))
194
-
195
- def forward(self, dec_inp, mask):
196
- output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2))
197
- output *= mask
198
- output = self.pos_ff(output)
199
- output *= mask
200
- return output
201
-
202
-
203
- class FFTransformer(nn.Module):
204
- def __init__(self, n_layer, n_head, d_model, d_head, d_inner, kernel_size,
205
- dropout, dropatt, dropemb=0.0, embed_input=True,
206
- n_embed=None, d_embed=None, padding_idx=0, pre_lnorm=False):
207
- super(FFTransformer, self).__init__()
208
- self.d_model = d_model
209
- self.n_head = n_head
210
- self.d_head = d_head
211
- self.padding_idx = padding_idx
212
- self.n_embed = n_embed
213
-
214
- self.embed_input = embed_input
215
- if embed_input:
216
- print(padding_idx) #########################################
217
- self.word_emb = nn.Embedding(n_embed, d_embed or d_model,
218
- padding_idx=self.padding_idx)
219
- else:
220
- self.word_emb = nn.Identity()
221
-
222
- self.pos_emb = PositionalEmbedding(self.d_model)
223
- self.drop = nn.Dropout(dropemb)
224
- self.layers = nn.ModuleList()
225
-
226
- for _ in range(n_layer):
227
- self.layers.append(
228
- TransformerLayer(
229
- n_head, d_model, d_head, d_inner, kernel_size, dropout,
230
- dropatt=dropatt, pre_lnorm=pre_lnorm)
231
- )
232
-
233
- def forward(self, dec_inp, seq_lens: Optional[torch.Tensor] = None,
234
- conditioning: Optional[torch.Tensor] = None):
235
- if not self.embed_input:
236
- inp = dec_inp
237
- assert seq_lens is not None
238
- mask = mask_from_lens(seq_lens).unsqueeze(2)
239
- else:
240
- inp = self.word_emb(dec_inp)
241
- # [bsz x L x 1]
242
- mask = (dec_inp != self.padding_idx).unsqueeze(2)
243
-
244
- pos_seq = torch.arange(inp.size(1), device=inp.device, dtype=inp.dtype)
245
- pos_emb = self.pos_emb(pos_seq) * mask
246
- if conditioning is not None:
247
- out = self.drop(inp + pos_emb + conditioning)
248
- else:
249
- out = self.drop(inp + pos_emb)
250
-
251
- for layer in self.layers:
252
- out = layer(out, mask=mask)
253
-
254
- # out = self.drop(out)
255
- return out, mask
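
Note: the _jit variant above differs from transformer.py mainly in the Optional[...] annotations on forward(), which is what makes the module compilable with TorchScript. A hedged sketch of how it would typically be exported (constructor values are placeholders; whether scripting succeeds also depends on the helpers in common.utils):

import torch
# placeholder import for the TorchScript-friendly variant defined above
# from fastpitch.transformer_jit import FFTransformer

encoder = FFTransformer(n_layer=2, n_head=2, d_model=64, d_head=32,
                        d_inner=256, kernel_size=3, dropout=0.1, dropatt=0.1,
                        embed_input=True, n_embed=148)
scripted = torch.jit.script(encoder)          # compile the module to TorchScript
scripted.save("fftransformer_scripted.pt")    # serialized graph, loadable without the Python source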
 
fastpitch/utils_trainplot_transformers/utils.py DELETED
@@ -1,291 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- # MIT License
16
- #
17
- # Copyright (c) 2020 Jungil Kong
18
- #
19
- # Permission is hereby granted, free of charge, to any person obtaining a copy
20
- # of this software and associated documentation files (the "Software"), to deal
21
- # in the Software without restriction, including without limitation the rights
22
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
23
- # copies of the Software, and to permit persons to whom the Software is
24
- # furnished to do so, subject to the following conditions:
25
- #
26
- # The above copyright notice and this permission notice shall be included in all
27
- # copies or substantial portions of the Software.
28
- #
29
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
30
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
31
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
32
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
33
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
34
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
35
- # SOFTWARE.
36
-
37
- # The following functions/classes were based on code from https://github.com/jik876/hifi-gan:
38
- # init_weights, get_padding, AttrDict
39
-
40
- import ctypes
41
- import glob
42
- import os
43
- import re
44
- import shutil
45
- import warnings
46
- from collections import defaultdict, OrderedDict
47
- from pathlib import Path
48
- from typing import Optional
49
-
50
- import librosa
51
- import numpy as np
52
-
53
- import torch
54
- import torch.distributed as dist
55
- from scipy.io.wavfile import read
56
-
57
-
58
- def mask_from_lens(lens, max_len: Optional[int] = None):
59
- if max_len is None:
60
- max_len = lens.max()
61
- ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
62
- mask = torch.lt(ids, lens.unsqueeze(1))
63
- return mask
64
-
65
-
66
- def load_wav(full_path, torch_tensor=False):
67
- import soundfile # flac
68
- data, sampling_rate = soundfile.read(full_path, dtype='int16')
69
- if torch_tensor:
70
- return torch.FloatTensor(data.astype(np.float32)), sampling_rate
71
- else:
72
- return data, sampling_rate
73
-
74
-
75
- def load_wav_to_torch(full_path, force_sampling_rate=None):
76
- if force_sampling_rate is not None:
77
- data, sampling_rate = librosa.load(full_path, sr=force_sampling_rate)
78
- else:
79
- sampling_rate, data = read(full_path)
80
-
81
- return torch.FloatTensor(data.astype(np.float32)), sampling_rate
82
-
83
-
84
- def load_filepaths_and_text(dataset_path, fnames, has_speakers=False, split="|"):
85
- def split_line(root, line):
86
- parts = line.strip().split(split)
87
- if has_speakers:
88
- paths, non_paths = parts[:-2], parts[-2:]
89
- else:
90
- paths, non_paths = parts[:-1], parts[-1:]
91
- return tuple(str(Path(root, p)) for p in paths) + tuple(non_paths)
92
-
93
- fpaths_and_text = []
94
- for fname in fnames:
95
- with open(fname, encoding='utf-8') as f:
96
- fpaths_and_text += [split_line(dataset_path, line) for line in f]
97
- return fpaths_and_text
98
-
99
-
100
- def to_gpu(x):
101
- x = x.contiguous()
102
- return x.cuda(non_blocking=True) if torch.cuda.is_available() else x
103
-
104
-
105
- def l2_promote():
106
- _libcudart = ctypes.CDLL('libcudart.so')
107
- # Set device limit on the current device
108
- # cudaLimitMaxL2FetchGranularity = 0x05
109
- pValue = ctypes.cast((ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int))
110
- _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
111
- _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
112
- assert pValue.contents.value == 128
113
-
114
-
115
- def prepare_tmp(path):
116
- if path is None:
117
- return
118
- p = Path(path)
119
- if p.is_dir():
120
- warnings.warn(f'{p} exists. Removing...')
121
- shutil.rmtree(p, ignore_errors=True)
122
- p.mkdir(parents=False, exist_ok=False)
123
-
124
-
125
- def print_once(*msg):
126
- if not dist.is_initialized() or dist.get_rank() == 0:
127
- print(*msg)
128
-
129
-
130
- def init_weights(m, mean=0.0, std=0.01):
131
- classname = m.__class__.__name__
132
- if classname.find("Conv") != -1:
133
- m.weight.data.normal_(mean, std)
134
-
135
-
136
- def get_padding(kernel_size, dilation=1):
137
- return int((kernel_size*dilation - dilation)/2)
138
-
139
-
140
- class AttrDict(dict):
141
- def __init__(self, *args, **kwargs):
142
- super(AttrDict, self).__init__(*args, **kwargs)
143
- self.__dict__ = self
144
-
145
-
146
- class DefaultAttrDict(defaultdict):
147
- def __init__(self, *args, **kwargs):
148
- super(DefaultAttrDict, self).__init__(*args, **kwargs)
149
- self.__dict__ = self
150
-
151
- def __getattr__(self, item):
152
- return self[item]
153
-
154
-
155
- class BenchmarkStats:
156
- """ Tracks statistics used for benchmarking. """
157
- def __init__(self):
158
- self.num_frames = []
159
- self.losses = []
160
- self.mel_losses = []
161
- self.took = []
162
-
163
- def update(self, num_frames, losses, mel_losses, took):
164
- self.num_frames.append(num_frames)
165
- self.losses.append(losses)
166
- self.mel_losses.append(mel_losses)
167
- self.took.append(took)
168
-
169
- def get(self, n_epochs):
170
- frames_s = sum(self.num_frames[-n_epochs:]) / sum(self.took[-n_epochs:])
171
- return {'frames/s': frames_s,
172
- 'loss': np.mean(self.losses[-n_epochs:]),
173
- 'mel_loss': np.mean(self.mel_losses[-n_epochs:]),
174
- 'took': np.mean(self.took[-n_epochs:]),
175
- 'benchmark_epochs_num': n_epochs}
176
-
177
- def __len__(self):
178
- return len(self.losses)
179
-
180
-
181
- class Checkpointer:
182
-
183
- def __init__(self, save_dir, keep_milestones=[]):
184
- self.save_dir = save_dir
185
- self.keep_milestones = keep_milestones
186
-
187
- find = lambda name: [
188
- (int(re.search(r"_(\d+)\.pt", fn).group(1)), fn)
189
- for fn in glob.glob(f"{save_dir}/{name}_checkpoint_*.pt")]
190
-
191
- tracked = sorted(find("FastPitch"), key=lambda t: t[0])
192
- self.tracked = OrderedDict(tracked)
193
-
194
- def last_checkpoint(self, output):
195
-
196
- def corrupted(fpath):
197
- try:
198
- torch.load(fpath, map_location="cpu")
199
- return False
200
- except:
201
- warnings.warn(f"Cannot load {fpath}")
202
- return True
203
-
204
- saved = sorted(
205
- glob.glob(f"{output}/FastPitch_checkpoint_*.pt"),
206
- key=lambda f: int(re.search(r"_(\d+)\.pt", f).group(1)))
207
-
208
- if len(saved) >= 1 and not corrupted(saved[-1]):
209
- return saved[-1]
210
- elif len(saved) >= 2:
211
- return saved[-2]
212
- else:
213
- return None
214
-
215
- def maybe_load(self, model, optimizer, scaler, train_state, args,
216
- ema_model=None):
217
-
218
- assert args.checkpoint_path is None or args.resume is False, (
219
- "Specify a single checkpoint source")
220
-
221
- fpath = None
222
- if args.checkpoint_path is not None:
223
- fpath = args.checkpoint_path
224
- self.tracked = OrderedDict() # Do not track/delete prev ckpts
225
- elif args.resume:
226
- fpath = self.last_checkpoint(args.output)
227
-
228
- if fpath is None:
229
- return
230
-
231
- print_once(f"Loading model and optimizer state from {fpath}")
232
- ckpt = torch.load(fpath, map_location="cpu")
233
- train_state["epoch"] = ckpt["epoch"] + 1
234
- train_state["total_iter"] = ckpt["iteration"]
235
-
236
- no_pref = lambda sd: {re.sub("^module.", "", k): v for k, v in sd.items()}
237
- unwrap = lambda m: getattr(m, "module", m)
238
-
239
- unwrap(model).load_state_dict(no_pref(ckpt["state_dict"]))
240
-
241
- if ema_model is not None:
242
- unwrap(ema_model).load_state_dict(no_pref(ckpt["ema_state_dict"]))
243
-
244
- optimizer.load_state_dict(ckpt["optimizer"])
245
-
246
- if "scaler" in ckpt:
247
- scaler.load_state_dict(ckpt["scaler"])
248
- else:
249
- warnings.warn("AMP scaler state missing from the checkpoint.")
250
-
251
- def maybe_save(self, args, model, ema_model, optimizer, scaler, epoch,
252
- total_iter, config):
253
-
254
- intermediate = (args.epochs_per_checkpoint > 0
255
- and epoch % args.epochs_per_checkpoint == 0)
256
- final = epoch == args.epochs
257
-
258
- if not intermediate and not final and epoch not in self.keep_milestones:
259
- return
260
-
261
- rank = 0
262
- if dist.is_initialized():
263
- dist.barrier()
264
- rank = dist.get_rank()
265
-
266
- if rank != 0:
267
- return
268
-
269
- unwrap = lambda m: getattr(m, "module", m)
270
- ckpt = {"epoch": epoch,
271
- "iteration": total_iter,
272
- "config": config,
273
- "train_setup": args.__dict__,
274
- "state_dict": unwrap(model).state_dict(),
275
- "optimizer": optimizer.state_dict(),
276
- "scaler": scaler.state_dict()}
277
- if ema_model is not None:
278
- ckpt["ema_state_dict"] = unwrap(ema_model).state_dict()
279
-
280
- fpath = Path(args.output, f"FastPitch_checkpoint_{epoch}.pt")
281
- print(f"Saving model and optimizer state at epoch {epoch} to {fpath}")
282
- torch.save(ckpt, fpath)
283
-
284
- # Remove old checkpoints; keep milestones and the last two
285
- self.tracked[epoch] = fpath
286
- for epoch in set(list(self.tracked)[:-2]) - set(self.keep_milestones):
287
- try:
288
- os.remove(self.tracked[epoch])
289
- except:
290
- pass
291
- del self.tracked[epoch]
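
Note: of the helpers above, mask_from_lens() is the one the transformer modules lean on most; it turns a batch of sequence lengths into a boolean padding mask. A small illustration with made-up lengths:

import torch
# defined in this file; imported elsewhere in the repo as `from common.utils import mask_from_lens`

lens = torch.tensor([3, 5, 2])
mask = mask_from_lens(lens)
print(mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True],
#         [ True,  True, False, False, False]])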
 
gradio_gui.py DELETED
@@ -1,74 +0,0 @@
1
- import gradio as gr
2
- import syn_hifigan as syn
3
- #import syn_k_univnet_multi as syn
4
- import os, tempfile
5
-
6
- languages = {"South Sámi":0,
7
- "North Sámi":1,
8
- "Lule Sámi":2}
9
-
10
- speakers={"aj0": 0,
11
- "aj1": 1,
12
- "am": 2,
13
- "bi": 3,
14
- "kd": 4,
15
- "ln": 5,
16
- "lo": 6,
17
- "ms": 7,
18
- "mu": 8,
19
- "sa": 9
20
- }
21
- public=False
22
-
23
- tempdir = tempfile.gettempdir()
24
-
25
- tts = syn.Synthesizer()
26
-
27
-
28
-
29
- def speak(text, language,speaker,l_weight, s_weight, pace, postfilter): #pitch_shift,pitch_std):
30
-
31
-
32
-
33
- # text frontend not implemented...
34
- text = text.replace("...", "…")
35
- print(speakers[speaker])
36
- audio = tts.speak(text, output_file=f'{tempdir}/tmp', lang=languages[language],
37
- spkr=speakers[speaker], l_weight=l_weight, s_weight=s_weight,
38
- pace=pace, clarity=postfilter)
39
-
40
- if not public:
41
- try:
42
- os.system("play "+tempdir+"/tmp.wav &")
43
- except:
44
- pass
45
-
46
- return (22050, audio)
47
-
48
-
49
-
50
- controls = []
51
- controls.append(gr.Textbox(label="text", value="Suohtas duinna deaivvadit."))
52
- controls.append(gr.Dropdown(list(languages.keys()), label="language", value="North Sámi"))
53
- controls.append(gr.Dropdown(list(speakers.keys()), label="speaker", value="ms"))
54
-
55
- controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1, label="language weight"))
56
- controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1, label="speaker weight"))
57
-
58
- controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1.0, label="speech rate"))
59
- controls.append(gr.Slider(minimum=0., maximum=2, step=0.05, value=1.0, label="post-processing"))
60
-
61
-
62
-
63
-
64
- tts_gui = gr.Interface(
65
- fn=speak,
66
- inputs=controls,
67
- outputs= gr.Audio(label="output"),
68
- live=False
69
-
70
- )
71
-
72
-
73
- if __name__ == "__main__":
74
- tts_gui.launch(share=public)
 
gradio_gui_katri.py DELETED
@@ -1,73 +0,0 @@
1
- import gradio as gr
2
- #import syn_hifigan as syn
3
- import syn_k_univnet_multi as syn
4
- import os, tempfile
5
-
6
- languages = {"South Sámi":0,
7
- "North Sámi":1,
8
- "Lule Sámi":2}
9
-
10
- speakers={"aj0": 0,
11
- "aj1": 1,
12
- "am": 2,
13
- "bi": 3,
14
- "kd": 4,
15
- "ln": 5,
16
- "lo": 6,
17
- "ms": 7,
18
- "mu": 8,
19
- "sa": 9
20
- }
21
- public=True
22
-
23
- tempdir = tempfile.gettempdir()
24
-
25
- tts = syn.Synthesizer()
26
-
27
-
28
-
29
- def speak(text, language,speaker,l_weight, s_weight, pace): #pitch_shift,pitch_std):
30
-
31
-
32
-
33
- # text frontend not implemented...
34
- text = text.replace("...", "…")
35
- print(speakers[speaker])
36
- audio = tts.speak(text, output_file=f'{tempdir}/tmp', lang=languages[language],
37
- spkr=speakers[speaker], l_weight=l_weight, s_weight=s_weight,
38
- pace=pace)
39
-
40
- if not public:
41
- try:
42
- os.system("play "+tempdir+"/tmp.wav &")
43
- except:
44
- pass
45
-
46
- return (22050, audio)
47
-
48
-
49
-
50
- controls = []
51
- controls.append(gr.Textbox(label="text", value="Suohtas duinna deaivvadit."))
52
- controls.append(gr.Dropdown(list(languages.keys()), label="language", value="North Sámi"))
53
- controls.append(gr.Dropdown(list(speakers.keys()), label="speaker", value="ms"))
54
-
55
- controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1, label="Language weight"))
56
- controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1, label="Speaker weight"))
57
- #controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1.0, label="Pitch variance"))
58
- controls.append(gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1.0, label="speech rate"))
59
-
60
-
61
-
62
-
63
- tts_gui = gr.Interface(
64
- fn=speak,
65
- inputs=controls,
66
- outputs= gr.Audio(label="output"),
67
- live=False
68
-
69
- )
70
-
71
-
72
- if __name__ == "__main__":
73
- tts_gui.launch(share=public)
 
prepare_dataset.py DELETED
@@ -1,180 +0,0 @@
1
- # *****************************************************************************
2
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
- #
4
- # Redistribution and use in source and binary forms, with or without
5
- # modification, are permitted provided that the following conditions are met:
6
- # * Redistributions of source code must retain the above copyright
7
- # notice, this list of conditions and the following disclaimer.
8
- # * Redistributions in binary form must reproduce the above copyright
9
- # notice, this list of conditions and the following disclaimer in the
10
- # documentation and/or other materials provided with the distribution.
11
- # * Neither the name of the NVIDIA CORPORATION nor the
12
- # names of its contributors may be used to endorse or promote products
13
- # derived from this software without specific prior written permission.
14
- #
15
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16
- # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17
- # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18
- # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19
- # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20
- # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21
- # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22
- # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23
- # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24
- # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
- #
26
- # *****************************************************************************
27
-
28
- import argparse
29
- import time
30
- from pathlib import Path
31
-
32
- import torch
33
- import tqdm
34
- import dllogger as DLLogger
35
- from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
36
- from torch.utils.data import DataLoader
37
-
38
- from fastpitch.data_function import TTSCollate, TTSDataset
39
-
40
-
41
- def parse_args(parser):
42
- """
43
- Parse commandline arguments.
44
- """
45
- parser.add_argument('-d', '--dataset-path', type=str,
46
- default='./', help='Path to dataset')
47
- parser.add_argument('--wav-text-filelists', required=True, nargs='+',
48
- type=str, help='Files with audio paths and text')
49
- parser.add_argument('--extract-mels', action='store_true',
50
- help='Calculate spectrograms from .wav files')
51
- parser.add_argument('--extract-pitch', action='store_true',
52
- help='Extract pitch')
53
- parser.add_argument('--save-alignment-priors', action='store_true',
54
- help='Pre-calculate diagonal matrices of alignment of text to audio')
55
- parser.add_argument('--log-file', type=str, default='preproc_log.json',
56
- help='Filename for logging')
57
- parser.add_argument('--n-speakers', type=int, default=1)
58
- parser.add_argument('--n-languages', type=int, default=1)
59
- # Mel extraction
60
- parser.add_argument('--max-wav-value', default=32768.0, type=float,
61
- help='Maximum audiowave value')
62
- parser.add_argument('--sampling-rate', default=22050, type=int,
63
- help='Sampling rate')
64
- parser.add_argument('--filter-length', default=1024, type=int,
65
- help='Filter length')
66
- parser.add_argument('--hop-length', default=256, type=int,
67
- help='Hop (stride) length')
68
- parser.add_argument('--win-length', default=1024, type=int,
69
- help='Window length')
70
- parser.add_argument('--mel-fmin', default=0.0, type=float,
71
- help='Minimum mel frequency')
72
- parser.add_argument('--mel-fmax', default=8000.0, type=float,
73
- help='Maximum mel frequency')
74
- parser.add_argument('--n-mel-channels', type=int, default=80)
75
- # Pitch extraction
76
- parser.add_argument('--f0-method', default='pyin', type=str,
77
- choices=['pyin'], help='F0 estimation method')
78
- parser.add_argument('--pitch-mean', default='214', type=float, ###
79
- help='F0 estimation method')
80
- parser.add_argument('--pitch-std', default='65', type=float, ####
81
- help='F0 estimation method')
82
- # Performance
83
- parser.add_argument('-b', '--batch-size', default=1, type=int)
84
- parser.add_argument('--n-workers', type=int, default=16)
85
- return parser
86
-
87
-
88
- def main():
89
- parser = argparse.ArgumentParser(description='FastPitch Data Pre-processing')
90
- parser = parse_args(parser)
91
- args, unk_args = parser.parse_known_args()
92
- if len(unk_args) > 0:
93
- raise ValueError(f'Invalid options {unk_args}')
94
-
95
- DLLogger.init(backends=[JSONStreamBackend(Verbosity.DEFAULT, Path(args.dataset_path, args.log_file)),
96
- StdOutBackend(Verbosity.VERBOSE)])
97
- for k, v in vars(args).items():
98
- DLLogger.log(step="PARAMETER", data={k: v})
99
- DLLogger.flush()
100
-
101
- if args.extract_mels:
102
- Path(args.dataset_path, 'mels').mkdir(parents=False, exist_ok=True)
103
-
104
- if args.extract_pitch:
105
- Path(args.dataset_path, 'pitch').mkdir(parents=False, exist_ok=True)
106
-
107
- if args.save_alignment_priors:
108
- Path(args.dataset_path, 'alignment_priors').mkdir(parents=False, exist_ok=True)
109
-
110
- for filelist in args.wav_text_filelists:
111
-
112
- print(f'Processing {filelist}...')
113
-
114
- dataset = TTSDataset(
115
- args.dataset_path,
116
- filelist,
117
- text_cleaners=['basic_cleaners'],
118
- n_mel_channels=args.n_mel_channels,
119
- p_arpabet=0.0,
120
- n_speakers=args.n_speakers,
121
- n_languages=args.n_languages,
122
- load_mel_from_disk=False,
123
- load_pitch_from_disk=False,
124
- pitch_mean=args.pitch_mean,
125
- pitch_std=args.pitch_std,
126
- max_wav_value=args.max_wav_value,
127
- sampling_rate=args.sampling_rate,
128
- filter_length=args.filter_length,
129
- hop_length=args.hop_length,
130
- win_length=args.win_length,
131
- mel_fmin=args.mel_fmin,
132
- mel_fmax=args.mel_fmax,
133
- betabinomial_online_dir=None,
134
- pitch_online_dir=None,
135
- pitch_online_method=args.f0_method)
136
-
137
- data_loader = DataLoader(
138
- dataset,
139
- batch_size=args.batch_size,
140
- shuffle=False,
141
- sampler=None,
142
- num_workers=args.n_workers,
143
- collate_fn=TTSCollate(),
144
- pin_memory=False,
145
- drop_last=False)
146
-
147
- all_filenames = set()
148
- for i, batch in enumerate(tqdm.tqdm(data_loader)):
149
- tik = time.time()
150
-
151
- _, input_lens, mels, mel_lens, _, pitch, _, _, _, attn_prior, fpaths = batch
152
-
153
- # Ensure filenames are unique
154
- for p in fpaths:
155
- fname = Path(p).name
156
- if fname in all_filenames:
157
- raise ValueError(f'Filename is not unique: {fname}')
158
- all_filenames.add(fname)
159
-
160
- if args.extract_mels:
161
- for j, mel in enumerate(mels):
162
- fname = Path(fpaths[j]).with_suffix('.pt').name
163
- fpath = Path(args.dataset_path, 'mels', fname)
164
- torch.save(mel[:, :mel_lens[j]], fpath)
165
-
166
- if args.extract_pitch:
167
- for j, p in enumerate(pitch):
168
- fname = Path(fpaths[j]).with_suffix('.pt').name
169
- fpath = Path(args.dataset_path, 'pitch', fname)
170
- torch.save(p[:mel_lens[j]], fpath)
171
-
172
- if args.save_alignment_priors:
173
- for j, prior in enumerate(attn_prior):
174
- fname = Path(fpaths[j]).with_suffix('.pt').name
175
- fpath = Path(args.dataset_path, 'alignment_priors', fname)
176
- torch.save(prior[:mel_lens[j], :input_lens[j]], fpath)
177
-
178
-
179
- if __name__ == '__main__':
180
- main()
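After this script runs, each utterance's features are cached as individual PyTorch tensors under <dataset-path>/mels, <dataset-path>/pitch and <dataset-path>/alignment_priors. A minimal sketch for inspecting one cached entry, assuming the ALL_SAMI dataset directory used elsewhere in this repo and a hypothetical utterance name utt_0001 (substitute a stem from your own filelist):

import torch
from pathlib import Path

dataset_path = Path("ALL_SAMI")  # dataset directory, as in scripts/prepare_dataset.sh
stem = "utt_0001"                # hypothetical utterance name from the filelist

mel = torch.load(dataset_path / "mels" / f"{stem}.pt")                # (n_mel_channels, n_frames)
pitch = torch.load(dataset_path / "pitch" / f"{stem}.pt")             # per-frame F0, trimmed to n_frames
prior = torch.load(dataset_path / "alignment_priors" / f"{stem}.pt")  # (n_frames, n_text_tokens) alignment prior
print(mel.shape, pitch.shape, prior.shape)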
run_training_cluster_s.sh DELETED
@@ -1,33 +0,0 @@
1
- #!/bin/bash
2
- #SBATCH --job-name=train_fastpitch
3
- #SBATCH --account=nn9866k
4
- #SBATCH --time=11:50:00
5
- #SBATCH --mem=16G
6
- #SBATCH --partition=accel
7
- #SBATCH --gres=gpu:1
8
-
9
- # == Logging
10
-
11
- #SBATCH --error="log_err" # Save the error messages
12
- #SBATCH --output="log_out" # Save the stdout
13
-
14
- ## Set up job environment:
15
- # set -o errexit # Exit the script on any error
16
- # set -o nounset # Treat any unset variables as an error
17
-
18
- ## Activate environment
19
- # source ~/.bashrc
20
-
21
- eval "$(conda shell.bash hook)"
22
- conda activate fastpitch
23
-
24
- # Setup monitoring
25
- nvidia-smi --query-gpu=timestamp,utilization.gpu,utilization.memory \
26
- --format=csv --loop=1 > "gpu_util-$SLURM_JOB_ID.csv" &
27
- NVIDIA_MONITOR_PID=$! # Capture PID of monitoring process
28
-
29
- # Run our computation
30
- bash scripts/train_2.sh
31
-
32
- # After computation stop monitoring
33
- kill -SIGINT "$NVIDIA_MONITOR_PID"
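This SLURM script is submitted with sbatch run_training_cluster_s.sh; while scripts/train_2.sh runs, the backgrounded nvidia-smi loop samples GPU and memory utilisation once per second into gpu_util-<job id>.csv, and the final kill stops that monitor once training finishes.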
scripts/docker/build.sh DELETED
@@ -1,3 +0,0 @@
1
- #!/usr/bin/env bash
2
-
3
- docker build . -t fastpitch:latest
scripts/docker/interactive.sh DELETED
@@ -1,5 +0,0 @@
1
- #!/usr/bin/env bash
2
-
3
- PORT=${PORT:-8888}
4
-
5
- docker run --gpus=all -it --rm -e CUDA_VISIBLE_DEVICES --ipc=host -p $PORT:$PORT -v $PWD:/workspace/fastpitch/ fastpitch:latest bash
scripts/download_cmudict.sh DELETED
@@ -1,10 +0,0 @@
1
- #!/usr/bin/env bash
2
-
3
- set -e
4
-
5
- : ${CMUDICT_DIR:="cmudict"}
6
-
7
- if [ ! -f $CMUDICT_DIR/cmudict-0.7b ]; then
8
- echo "Downloading cmudict-0.7b ..."
9
- wget https://github.com/Alexir/CMUdict/raw/master/cmudict-0.7b -qO $CMUDICT_DIR/cmudict-0.7b
10
- fi
scripts/download_dataset.sh DELETED
@@ -1,17 +0,0 @@
1
- #!/usr/bin/env bash
2
-
3
- set -e
4
-
5
- scripts/download_cmudict.sh
6
-
7
- DATA_DIR="LJSpeech-1.1"
8
- LJS_ARCH="LJSpeech-1.1.tar.bz2"
9
- LJS_URL="http://data.keithito.com/data/speech/${LJS_ARCH}"
10
-
11
- if [ ! -d ${DATA_DIR} ]; then
12
- echo "Downloading ${LJS_ARCH} ..."
13
- wget -q ${LJS_URL}
14
- echo "Extracting ${LJS_ARCH} ..."
15
- tar jxvf ${LJS_ARCH}
16
- rm -f ${LJS_ARCH}
17
- fi
scripts/download_models.sh DELETED
@@ -1,63 +0,0 @@
1
- #!/usr/bin/env bash
2
-
3
- set -e
4
-
5
- MODEL_NAMES="$@"
6
- [ -z "$MODEL_NAMES" ] && { echo "Usage: $0 [fastpitch|waveglow|hifigan|hifigan-finetuned-fastpitch]"; exit 1; }
7
-
8
- function download_ngc_model() {
9
- mkdir -p "$MODEL_DIR"
10
-
11
- if [ ! -f "${MODEL_DIR}/${MODEL_ZIP}" ]; then
12
- echo "Downloading ${MODEL_ZIP} ..."
13
- wget --content-disposition -O ${MODEL_DIR}/${MODEL_ZIP} ${MODEL_URL} \
14
- || { echo "ERROR: Failed to download ${MODEL_ZIP} from NGC"; exit 1; }
15
- fi
16
-
17
- if [ ! -f "${MODEL_DIR}/${MODEL}" ]; then
18
- echo "Extracting ${MODEL} ..."
19
- unzip -qo ${MODEL_DIR}/${MODEL_ZIP} -d ${MODEL_DIR} \
20
- || { echo "ERROR: Failed to extract ${MODEL_ZIP}"; exit 1; }
21
-
22
- echo "OK"
23
-
24
- else
25
- echo "${MODEL} already downloaded."
26
- fi
27
-
28
- }
29
-
30
- for MODEL_NAME in $MODEL_NAMES
31
- do
32
- case $MODEL_NAME in
33
- "fastpitch")
34
- MODEL_DIR="pretrained_models/fastpitch"
35
- MODEL_ZIP="fastpitch_pyt_fp32_ckpt_v1_1_21.05.0.zip"
36
- MODEL="nvidia_fastpitch_210824.pt"
37
- MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/fastpitch_pyt_fp32_ckpt_v1_1/versions/21.05.0/zip"
38
- ;;
39
- "hifigan")
40
- MODEL_DIR="pretrained_models/hifigan"
41
- MODEL_ZIP="hifigan__pyt_ckpt_ds-ljs22khz_21.08.0_amp.zip"
42
- MODEL="hifigan_gen_checkpoint_6500.pt"
43
- MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/dle/hifigan__pyt_ckpt_ds-ljs22khz/versions/21.08.0_amp/zip"
44
- ;;
45
- "hifigan-finetuned-fastpitch")
46
- MODEL_DIR="pretrained_models/hifigan"
47
- MODEL_ZIP="hifigan__pyt_ckpt_mode-finetune_ds-ljs22khz_21.08.0_amp.zip"
48
- MODEL="hifigan_gen_checkpoint_10000_ft.pt"
49
- MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/dle/hifigan__pyt_ckpt_mode-finetune_ds-ljs22khz/versions/21.08.0_amp/zip"
50
- ;;
51
- "waveglow")
52
- MODEL_DIR="pretrained_models/waveglow"
53
- MODEL_ZIP="waveglow_ckpt_amp_256_20.01.0.zip"
54
- MODEL="nvidia_waveglow256pyt_fp16.pt"
55
- MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/waveglow_ckpt_amp_256/versions/20.01.0/zip"
56
- ;;
57
- *)
58
- echo "Unrecognized model: ${MODEL_NAME}"
59
- exit 2
60
- ;;
61
- esac
62
- download_ngc_model "$MODEL_DIR" "$MODEL_ZIP" "$MODEL" "$MODEL_URL"
63
- done
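Usage follows the message at the top of the script, e.g. bash scripts/download_models.sh fastpitch hifigan-finetuned-fastpitch waveglow fetches each named checkpoint archive from NVIDIA NGC into its pretrained_models/ subdirectory, extracts it, and skips models that are already present.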
scripts/inference_benchmark.sh DELETED
@@ -1,16 +0,0 @@
1
- #!/usr/bin/env bash
2
-
3
- set -a
4
-
5
- : ${FILELIST:="phrases/benchmark_8_128.tsv"}
6
- : ${OUTPUT_DIR:="./output/audio_$(basename ${FILELIST} .tsv)"}
7
- : ${TORCHSCRIPT:=true}
8
- : ${BS_SEQUENCE:="1 4 8"}
9
- : ${WARMUP:=64}
10
- : ${REPEATS:=500}
11
- : ${AMP:=false}
12
-
13
- for BATCH_SIZE in $BS_SEQUENCE ; do
14
- LOG_FILE="$OUTPUT_DIR"/perf-infer_amp-${AMP}_bs${BATCH_SIZE}.json
15
- bash scripts/inference_example.sh "$@"
16
- done
scripts/inference_example.sh DELETED
@@ -1,78 +0,0 @@
1
- #!/usr/bin/env bash
2
-
3
- export CUDNN_V8_API_ENABLED=1 # Keep the flag for older containers
4
- export TORCH_CUDNN_V8_API_ENABLED=1
5
-
6
- : ${DATASET_DIR:="sander_splits"}
7
- : ${BATCH_SIZE:=1}
8
- : ${FILELIST:="phrases/giehttjit.txt"}
9
- : ${AMP:=false}
10
- : ${TORCHSCRIPT:=true}
11
- : ${WARMUP:=0}
12
- : ${REPEATS:=1}
13
- : ${CPU:=false}
14
- : ${PHONE:=true}
15
-
16
- # Paths to pre-trained models downloadable from NVIDIA NGC (LJSpeech-1.1)
17
- FASTPITCH_LJ="output/FastPitch_checkpoint_660.pt"
18
- HIFIGAN_LJ="pretrained_models/hifigan/hifigan_gen_checkpoint_10000_ft.pt"
19
- WAVEGLOW_LJ="pretrained_models/waveglow/nvidia_waveglow256pyt_fp16.pt"
20
-
21
- # Mel-spectrogram generator (optional; can synthesize from ground-truth spectrograms)
22
- : ${FASTPITCH=$FASTPITCH_LJ}
23
-
24
- # Vocoder (set only one)
25
- #: ${HIFIGAN=$HIFIGAN_LJ}
26
- : ${WAVEGLOW=$WAVEGLOW_LJ}
27
-
28
- [[ "$FASTPITCH" == "$FASTPITCH_LJ" && ! -f "$FASTPITCH" ]] && { echo "Downloading $FASTPITCH from NGC..."; bash scripts/download_models.sh fastpitch; }
29
- [[ "$WAVEGLOW" == "$WAVEGLOW_LJ" && ! -f "$WAVEGLOW" ]] && { echo "Downloading $WAVEGLOW from NGC..."; bash scripts/download_models.sh waveglow; }
30
- [[ "$HIFIGAN" == "$HIFIGAN_LJ" && ! -f "$HIFIGAN" ]] && { echo "Downloading $HIFIGAN from NGC..."; bash scripts/download_models.sh hifigan-finetuned-fastpitch; }
31
-
32
- if [[ "$HIFIGAN" == "$HIFIGAN_LJ" && "$FASTPITCH" != "$FASTPITCH_LJ" ]]; then
33
- echo -e "\nNOTE: Using HiFi-GAN checkpoint trained for the LJSpeech-1.1 dataset."
34
- echo -e "NOTE: If you're using a different dataset, consider training a new HiFi-GAN model or switch to WaveGlow."
35
- echo -e "NOTE: See $0 for details.\n"
36
- fi
37
-
38
- # Synthesis
39
- : ${SPEAKER:=0}
40
- : ${DENOISING:=0.01}
41
-
42
- if [ ! -n "$OUTPUT_DIR" ]; then
43
- OUTPUT_DIR="./output/audio_$(basename ${FILELIST} .tsv)"
44
- [ "$AMP" = true ] && OUTPUT_DIR+="_fp16"
45
- [ "$AMP" = false ] && OUTPUT_DIR+="_fp32"
46
- [ -n "$FASTPITCH" ] && OUTPUT_DIR+="_fastpitch"
47
- [ ! -n "$FASTPITCH" ] && OUTPUT_DIR+="_gt-mel"
48
- [ -n "$WAVEGLOW" ] && OUTPUT_DIR+="_waveglow"
49
- [ -n "$HIFIGAN" ] && OUTPUT_DIR+="_hifigan"
50
- OUTPUT_DIR+="_denoise-"${DENOISING}
51
- fi
52
- : ${LOG_FILE:="$OUTPUT_DIR/nvlog_infer.json"}
53
- mkdir -p "$OUTPUT_DIR"
54
-
55
- echo -e "\nAMP=$AMP, batch_size=$BATCH_SIZE\n"
56
-
57
- ARGS=""
58
- ARGS+=" --cuda"
59
- # ARGS+=" --cudnn-benchmark" # Enable for benchmarking or long operation
60
- ARGS+=" --dataset-path $DATASET_DIR"
61
- ARGS+=" -i $FILELIST"
62
- ARGS+=" -o $OUTPUT_DIR"
63
- ARGS+=" --log-file $LOG_FILE"
64
- ARGS+=" --batch-size $BATCH_SIZE"
65
- ARGS+=" --denoising-strength $DENOISING"
66
- ARGS+=" --warmup-steps $WARMUP"
67
- ARGS+=" --repeats $REPEATS"
68
- ARGS+=" --speaker $SPEAKER"
69
- [ "$CPU" = false ] && ARGS+=" --cuda"
70
- [ "$CPU" = false ] && ARGS+=" --cudnn-benchmark"
71
- [ "$AMP" = true ] && ARGS+=" --amp"
72
- [ "$TORCHSCRIPT" = true ] && ARGS+=" --torchscript"
73
- [ -n "$HIFIGAN" ] && ARGS+=" --hifigan $HIFIGAN"
74
- [ -n "$WAVEGLOW" ] && ARGS+=" --waveglow $WAVEGLOW"
75
- [ -n "$FASTPITCH" ] && ARGS+=" --fastpitch $FASTPITCH"
76
- [ "$PHONE" = true ] && ARGS+=" --p-arpabet 1.0"
77
-
78
- python inference.py $ARGS "$@"
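Every setting above is declared with the : ${VAR:=default} idiom, so it can be overridden from the environment without editing the script, e.g. FILELIST=phrases/giehttjit.txt SPEAKER=3 AMP=true bash scripts/inference_example.sh; any extra flags given after the script name are passed straight through to inference.py via "$@".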
scripts/prepare_dataset.sh DELETED
@@ -1,19 +0,0 @@
1
- #!/usr/bin/env bash
2
-
3
- set -e
4
-
5
- : ${DATA_DIR:=ALL_SAMI}
6
- : ${ARGS="--extract-mels"}
7
-
8
- python prepare_dataset.py \
9
- --wav-text-filelists filelists/all_sami_filelist_shuf_200_train.txt \
10
- --n-workers 8 \
11
- --batch-size 1 \
12
- --dataset-path $DATA_DIR \
13
- --extract-pitch \
14
- --f0-method pyin \
15
- --pitch-mean 150 \
16
- --pitch-std 40 \
17
- --n-speakers 10 \
18
- --n-languages 3 \
19
- $ARGS
scripts/train.sh DELETED
@@ -1,100 +0,0 @@
1
- #!/usr/bin/env bash
2
-
3
- export OMP_NUM_THREADS=1
4
-
5
- : ${NUM_GPUS:=8}
6
- : ${BATCH_SIZE:=16}
7
- : ${GRAD_ACCUMULATION:=2}
8
- : ${OUTPUT_DIR:="./output"}
9
- : ${LOG_FILE:=$OUTPUT_DIR/nvlog.json}
10
- : ${DATASET_PATH:=LJSpeech-1.1}
11
- : ${TRAIN_FILELIST:=filelists/ljs_audio_pitch_text_train_v3.txt}
12
- : ${VAL_FILELIST:=filelists/ljs_audio_pitch_text_val.txt}
13
- : ${AMP:=false}
14
- : ${SEED:=""}
15
-
16
- : ${LEARNING_RATE:=0.1}
17
-
18
- # Adjust these when the amount of data changes
19
- : ${EPOCHS:=1000}
20
- : ${EPOCHS_PER_CHECKPOINT:=20}
21
- : ${WARMUP_STEPS:=1000}
22
- : ${KL_LOSS_WARMUP:=100}
23
-
24
- # Train a mixed phoneme/grapheme model
25
- : ${PHONE:=true}
26
- # Enable energy conditioning
27
- : ${ENERGY:=true}
28
- : ${TEXT_CLEANERS:=english_cleaners_v2}
29
- # Add dummy space prefix/suffix if audio is not precisely trimmed
30
- : ${APPEND_SPACES:=false}
31
-
32
- : ${LOAD_PITCH_FROM_DISK:=true}
33
- : ${LOAD_MEL_FROM_DISK:=false}
34
-
35
- # For multispeaker models, add speaker ID = {0, 1, ...} as the last filelist column
36
- : ${NSPEAKERS:=1}
37
- : ${SAMPLING_RATE:=22050}
38
-
39
- # Adjust env variables to maintain the global batch size: NUM_GPUS x BATCH_SIZE x GRAD_ACCUMULATION = 256.
40
- GBS=$(($NUM_GPUS * $BATCH_SIZE * $GRAD_ACCUMULATION))
41
- [ $GBS -ne 256 ] && echo -e "\nWARNING: Global batch size changed from 256 to ${GBS}."
42
- echo -e "\nAMP=$AMP, ${NUM_GPUS}x${BATCH_SIZE}x${GRAD_ACCUMULATION}" \
43
- "(global batch size ${GBS})\n"
44
-
45
- ARGS=""
46
- ARGS+=" --cuda"
47
- ARGS+=" -o $OUTPUT_DIR"
48
- ARGS+=" --log-file $LOG_FILE"
49
- ARGS+=" --dataset-path $DATASET_PATH"
50
- ARGS+=" --training-files $TRAIN_FILELIST"
51
- ARGS+=" --validation-files $VAL_FILELIST"
52
- ARGS+=" -bs $BATCH_SIZE"
53
- ARGS+=" --grad-accumulation $GRAD_ACCUMULATION"
54
- ARGS+=" --optimizer lamb"
55
- ARGS+=" --epochs $EPOCHS"
56
- ARGS+=" --epochs-per-checkpoint $EPOCHS_PER_CHECKPOINT"
57
- ARGS+=" --resume"
58
- ARGS+=" --warmup-steps $WARMUP_STEPS"
59
- ARGS+=" -lr $LEARNING_RATE"
60
- ARGS+=" --weight-decay 1e-6"
61
- ARGS+=" --grad-clip-thresh 1000.0"
62
- ARGS+=" --dur-predictor-loss-scale 0.1"
63
- ARGS+=" --pitch-predictor-loss-scale 0.1"
64
- ARGS+=" --trainloader-repeats 100"
65
- ARGS+=" --validation-freq 10"
66
-
67
- # Autoalign & new features
68
- ARGS+=" --kl-loss-start-epoch 0"
69
- ARGS+=" --kl-loss-warmup-epochs $KL_LOSS_WARMUP"
70
- ARGS+=" --text-cleaners $TEXT_CLEANERS"
71
- ARGS+=" --n-speakers $NSPEAKERS"
72
-
73
- [ "$AMP" = "true" ] && ARGS+=" --amp"
74
- [ "$PHONE" = "true" ] && ARGS+=" --p-arpabet 1.0"
75
- [ "$ENERGY" = "true" ] && ARGS+=" --energy-conditioning"
76
- [ "$SEED" != "" ] && ARGS+=" --seed $SEED"
77
- [ "$LOAD_MEL_FROM_DISK" = true ] && ARGS+=" --load-mel-from-disk"
78
- [ "$LOAD_PITCH_FROM_DISK" = true ] && ARGS+=" --load-pitch-from-disk"
79
- [ "$PITCH_ONLINE_DIR" != "" ] && ARGS+=" --pitch-online-dir $PITCH_ONLINE_DIR" # e.g., /dev/shm/pitch
80
- [ "$PITCH_ONLINE_METHOD" != "" ] && ARGS+=" --pitch-online-method $PITCH_ONLINE_METHOD"
81
- [ "$APPEND_SPACES" = true ] && ARGS+=" --prepend-space-to-text"
82
- [ "$APPEND_SPACES" = true ] && ARGS+=" --append-space-to-text"
83
-
84
- if [ "$SAMPLING_RATE" == "44100" ]; then
85
- ARGS+=" --sampling-rate 44100"
86
- ARGS+=" --filter-length 2048"
87
- ARGS+=" --hop-length 512"
88
- ARGS+=" --win-length 2048"
89
- ARGS+=" --mel-fmin 0.0"
90
- ARGS+=" --mel-fmax 22050.0"
91
-
92
- elif [ "$SAMPLING_RATE" != "22050" ]; then
93
- echo "Unknown sampling rate $SAMPLING_RATE"
94
- exit 1
95
- fi
96
-
97
- mkdir -p "$OUTPUT_DIR"
98
-
99
- : ${DISTRIBUTED:="-m torch.distributed.launch --nproc_per_node $NUM_GPUS"}
100
- python $DISTRIBUTED train.py $ARGS "$@"
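With the defaults above the global batch size works out to 8 GPUs × batch 16 × 2 accumulation steps = 256, matching the reference value, so the batch-size warning stays silent; change any of the three factors and GBS is recomputed and the warning printed accordingly.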
scripts/train_multilang.sh DELETED
@@ -1,110 +0,0 @@
1
- #!/usr/bin/env bash
2
-
3
- export OMP_NUM_THREADS=1
4
-
5
- : ${NUM_GPUS:=1}
6
- : ${BATCH_SIZE:=1}
7
- : ${GRAD_ACCUMULATION:=32}
8
- : ${OUTPUT_DIR:="./output_multilang"}
9
- : ${LOG_FILE:=$OUTPUT_DIR/nvlog.json}
10
- : ${DATASET_PATH:=ALL_SAMI}
11
- #: ${DATASET_PATH:=mikal_urheim}
12
- #: ${TRAIN_FILELIST:=filelists/smj_sander_text_noshorts_shuff_pitch.txt}
13
- #: ${TRAIN_FILELIST:=filelists/mikal_urheim_pitch_shuf.txt}
14
- : ${TRAIN_FILELIST:=filelists/all_sami_filelist_shuf_200_train.txt}
15
- #: ${VAL_FILELIST:=filelists/smj_sander_text_noshorts_shuff_val_pitch.txt}
16
- #: ${VAL_FILELIST:=filelists/mikal_urheim_pitch_shuf_val.txt}
17
- : ${VAL_FILELIST:=filelists/all_sami_filelist_shuf_200_val.txt}
18
- : ${AMP:=false}
19
- : ${SEED:=""}
20
-
21
- : ${LEARNING_RATE:=0.1}
22
-
23
- # Adjust these when the amount of data changes
24
- : ${EPOCHS:=1000}
25
- : ${EPOCHS_PER_CHECKPOINT:=10}
26
- : ${WARMUP_STEPS:=1000}
27
- : ${KL_LOSS_WARMUP:=100}
28
-
29
- # Train a mixed phoneme/grapheme model
30
- : ${PHONE:=false}
31
- # Enable energy conditioning
32
- : ${ENERGY:=true}
33
- : ${TEXT_CLEANERS:=basic_cleaners}
34
- : ${SYMBOL_SET:=all_sami}
35
- # Add dummy space prefix/suffix is audio is not precisely trimmed
36
- : ${APPEND_SPACES:=false}
37
-
38
- : ${LOAD_PITCH_FROM_DISK:=true} # was true
39
- : ${LOAD_MEL_FROM_DISK:=false}
40
-
41
- # For multispeaker models, add speaker ID = {0, 1, ...} as the last filelist column
42
- : ${NSPEAKERS:=10} # 10
43
- : ${NLANGUAGES:=3} # 3
44
- : ${SAMPLING_RATE:=22050}
45
-
46
- # Adjust env variables to maintain the global batch size: NUM_GPUS x BATCH_SIZE x GRAD_ACCUMULATION = 256.
47
- GBS=$(($NUM_GPUS * $BATCH_SIZE * $GRAD_ACCUMULATION))
48
- [ $GBS -ne 256 ] && echo -e "\nWARNING: Global batch size changed from 256 to ${GBS}."
49
- echo -e "\nAMP=$AMP, ${NUM_GPUS}x${BATCH_SIZE}x${GRAD_ACCUMULATION}" \
50
- "(global batch size ${GBS})\n"
51
-
52
- ARGS=""
53
- ARGS+=" --cuda"
54
- ARGS+=" -o $OUTPUT_DIR"
55
- ARGS+=" --log-file $LOG_FILE"
56
- ARGS+=" --dataset-path $DATASET_PATH"
57
- ARGS+=" --training-files $TRAIN_FILELIST"
58
- ARGS+=" --validation-files $VAL_FILELIST"
59
- ARGS+=" -bs $BATCH_SIZE"
60
- ARGS+=" --grad-accumulation $GRAD_ACCUMULATION"
61
- ARGS+=" --optimizer lamb" #adam
62
- ARGS+=" --epochs $EPOCHS"
63
- ARGS+=" --epochs-per-checkpoint $EPOCHS_PER_CHECKPOINT"
64
- ARGS+=" --resume"
65
- ARGS+=" --warmup-steps $WARMUP_STEPS"
66
- ARGS+=" -lr $LEARNING_RATE"
67
- ARGS+=" --weight-decay 1e-6"
68
- ARGS+=" --grad-clip-thresh 1000.0"
69
- ARGS+=" --dur-predictor-loss-scale 0.1"
70
- ARGS+=" --pitch-predictor-loss-scale 0.1"
71
- ARGS+=" --trainloader-repeats 100"
72
- ARGS+=" --validation-freq 1" #10
73
-
74
- # Autoalign & new features
75
- ARGS+=" --kl-loss-start-epoch 0"
76
- ARGS+=" --kl-loss-warmup-epochs $KL_LOSS_WARMUP"
77
- ARGS+=" --text-cleaners $TEXT_CLEANERS"
78
- ARGS+=" --n-speakers $NSPEAKERS"
79
- ARGS+=" --n-languages $NLANGUAGES"
80
- ARGS+=" --symbol-set $SYMBOL_SET"
81
-
82
- [ "$AMP" = "true" ] && ARGS+=" --amp"
83
- [ "$PHONE" = "true" ] && ARGS+=" --p-arpabet 1.0"
84
- [ "$ENERGY" = "true" ] && ARGS+=" --energy-conditioning"
85
- [ "$SEED" != "" ] && ARGS+=" --seed $SEED"
86
- [ "$LOAD_MEL_FROM_DISK" = true ] && ARGS+=" --load-mel-from-disk"
87
- [ "$LOAD_PITCH_FROM_DISK" = true ] && ARGS+=" --load-pitch-from-disk"
88
- [ "$PITCH_ONLINE_DIR" != "" ] && ARGS+=" --pitch-online-dir $PITCH_ONLINE_DIR" # e.g., /dev/shm/pitch
89
- [ "$PITCH_ONLINE_METHOD" != "" ] && ARGS+=" --pitch-online-method $PITCH_ONLINE_METHOD"
90
- [ "$APPEND_SPACES" = true ] && ARGS+=" --prepend-space-to-text"
91
- [ "$APPEND_SPACES" = true ] && ARGS+=" --append-space-to-text"
92
-
93
- if [ "$SAMPLING_RATE" == "44100" ]; then
94
- ARGS+=" --sampling-rate 44100"
95
- ARGS+=" --filter-length 2048"
96
- ARGS+=" --hop-length 512"
97
- ARGS+=" --win-length 2048"
98
- ARGS+=" --mel-fmin 0.0"
99
- ARGS+=" --mel-fmax 22050.0"
100
-
101
- elif [ "$SAMPLING_RATE" != "22050" ]; then
102
- echo "Unknown sampling rate $SAMPLING_RATE"
103
- exit 1
104
- fi
105
-
106
- mkdir -p "$OUTPUT_DIR"
107
-
108
- : ${DISTRIBUTED:="-m torch.distributed.launch --nproc_per_node $NUM_GPUS"}
109
- #python $DISTRIBUTED train.py $ARGS "$@"
110
- python train_1_with_plot_multilang.py $ARGS "$@"
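Here the defaults give a global batch size of 1 × 1 × 32 = 32, so the "Global batch size changed from 256" warning is expected on every run; keeping the effective batch at the reference 256 on a single GPU would mean raising GRAD_ACCUMULATION (e.g. GRAD_ACCUMULATION=256 bash scripts/train_multilang.sh) or BATCH_SIZE as memory allows. Note also that the distributed launcher is commented out and train_1_with_plot_multilang.py is started directly, matching the single-GPU setup.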
train_1_with_plot_multilang.py DELETED
@@ -1,593 +0,0 @@
1
- # *****************************************************************************
2
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
- #
4
- # Redistribution and use in source and binary forms, with or without
5
- # modification, are permitted provided that the following conditions are met:
6
- # * Redistributions of source code must retain the above copyright
7
- # notice, this list of conditions and the following disclaimer.
8
- # * Redistributions in binary form must reproduce the above copyright
9
- # notice, this list of conditions and the following disclaimer in the
10
- # documentation and/or other materials provided with the distribution.
11
- # * Neither the name of the NVIDIA CORPORATION nor the
12
- # names of its contributors may be used to endorse or promote products
13
- # derived from this software without specific prior written permission.
14
- #
15
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16
- # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17
- # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18
- # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19
- # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20
- # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21
- # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22
- # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23
- # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24
- # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
- #
26
- # *****************************************************************************
27
-
28
- import argparse
29
- import copy
30
- import os
31
- import time
32
- from collections import defaultdict, OrderedDict
33
- from itertools import cycle
34
-
35
- import numpy as np
36
- import torch
37
- import torch.distributed as dist
38
- import amp_C
39
- from apex.optimizers import FusedAdam, FusedLAMB
40
- from torch.nn.parallel import DistributedDataParallel
41
- from torch.utils.data import DataLoader
42
- from torch.utils.data.distributed import DistributedSampler
43
-
44
- import common.tb_dllogger as logger
45
- import models
46
- from common.tb_dllogger import log
47
- from common.repeated_dataloader import (RepeatedDataLoader,
48
- RepeatedDistributedSampler)
49
- from common.text import cmudict
50
- from common.utils import BenchmarkStats, Checkpointer, prepare_tmp
51
- from fastpitch.attn_loss_function import AttentionBinarizationLoss
52
- from fastpitch.data_function import batch_to_gpu, TTSCollate, TTSDataset
53
- from fastpitch.loss_function import FastPitchLoss
54
-
55
- import matplotlib.pyplot as plt
56
-
57
- def parse_args(parser):
58
- parser.add_argument('-o', '--output', type=str, required=True,
59
- help='Directory to save checkpoints')
60
- parser.add_argument('-d', '--dataset-path', type=str, default='./',
61
- help='Path to dataset')
62
- parser.add_argument('--log-file', type=str, default=None,
63
- help='Path to a DLLogger log file')
64
-
65
- train = parser.add_argument_group('training setup')
66
- train.add_argument('--epochs', type=int, required=True,
67
- help='Number of total epochs to run')
68
- train.add_argument('--epochs-per-checkpoint', type=int, default=50,
69
- help='Number of epochs per checkpoint')
70
- train.add_argument('--checkpoint-path', type=str, default=None,
71
- help='Checkpoint path to resume training')
72
- train.add_argument('--keep-milestones', default=list(range(100, 1000, 100)),
73
- type=int, nargs='+',
74
- help='Milestone checkpoints to keep from removing')
75
- train.add_argument('--resume', action='store_true',
76
- help='Resume training from the last checkpoint')
77
- train.add_argument('--seed', type=int, default=1234,
78
- help='Seed for PyTorch random number generators')
79
- train.add_argument('--amp', action='store_true',
80
- help='Enable AMP')
81
- train.add_argument('--cuda', action='store_true',
82
- help='Run on GPU using CUDA')
83
- train.add_argument('--cudnn-benchmark', action='store_true',
84
- help='Enable cudnn benchmark mode')
85
- train.add_argument('--ema-decay', type=float, default=0,
86
- help='Discounting factor for training weights EMA')
87
- train.add_argument('--grad-accumulation', type=int, default=1,
88
- help='Training steps to accumulate gradients for')
89
- train.add_argument('--kl-loss-start-epoch', type=int, default=250,
90
- help='Start adding the hard attention loss term')
91
- train.add_argument('--kl-loss-warmup-epochs', type=int, default=100,
92
- help='Gradually increase the hard attention loss term')
93
- train.add_argument('--kl-loss-weight', type=float, default=1.0,
94
- help='Weight of the hard attention (KL) loss term')
95
- train.add_argument('--benchmark-epochs-num', type=int, default=20,
96
- help='Number of epochs for calculating final stats')
97
- train.add_argument('--validation-freq', type=int, default=1,
98
- help='Validate every N epochs to use less compute')
99
-
100
- opt = parser.add_argument_group('optimization setup')
101
- opt.add_argument('--optimizer', type=str, default='lamb',
102
- help='Optimization algorithm')
103
- opt.add_argument('-lr', '--learning-rate', type=float, required=True,
104
- help='Learning rate')
105
- opt.add_argument('--weight-decay', default=1e-6, type=float,
106
- help='Weight decay')
107
- opt.add_argument('--grad-clip-thresh', default=1000.0, type=float,
108
- help='Clip threshold for gradients')
109
- opt.add_argument('-bs', '--batch-size', type=int, required=True,
110
- help='Batch size per GPU')
111
- opt.add_argument('--warmup-steps', type=int, default=1000,
112
- help='Number of steps for lr warmup')
113
- opt.add_argument('--dur-predictor-loss-scale', type=float,
114
- default=1.0, help='Rescale duration predictor loss')
115
- opt.add_argument('--pitch-predictor-loss-scale', type=float,
116
- default=1.0, help='Rescale pitch predictor loss')
117
- opt.add_argument('--attn-loss-scale', type=float,
118
- default=1.0, help='Rescale alignment loss')
119
-
120
- data = parser.add_argument_group('dataset parameters')
121
- data.add_argument('--training-files', type=str, nargs='*', required=True,
122
- help='Paths to training filelists.')
123
- data.add_argument('--validation-files', type=str, nargs='*',
124
- required=True, help='Paths to validation filelists')
125
- data.add_argument('--text-cleaners', nargs='*',
126
- default=['english_cleaners'], type=str,
127
- help='Type of text cleaners for input text')
128
- data.add_argument('--symbol-set', type=str, default='english_basic',
129
- help='Define symbol set for input text')
130
- data.add_argument('--p-arpabet', type=float, default=0.0,
131
- help='Probability of using arpabets instead of graphemes '
132
- 'for each word; set 0 for pure grapheme training')
133
- data.add_argument('--heteronyms-path', type=str, default='cmudict/heteronyms',
134
- help='Path to the list of heteronyms')
135
- data.add_argument('--cmudict-path', type=str, default='cmudict/cmudict-0.7b',
136
- help='Path to the pronouncing dictionary')
137
- data.add_argument('--prepend-space-to-text', action='store_true',
138
- help='Capture leading silence with a space token')
139
- data.add_argument('--append-space-to-text', action='store_true',
140
- help='Capture trailing silence with a space token')
141
- data.add_argument('--num-workers', type=int, default=2, # 6
142
- help='Subprocesses for train and val DataLoaders')
143
- data.add_argument('--trainloader-repeats', type=int, default=100,
144
- help='Repeats the dataset to prolong epochs')
145
-
146
- cond = parser.add_argument_group('data for conditioning')
147
- cond.add_argument('--n-speakers', type=int, default=1,
148
- help='Number of speakers in the dataset. '
149
- 'n_speakers > 1 enables speaker embeddings')
150
- # ANT: added language
151
- cond.add_argument('--n-languages', type=int, default=1,
152
- help='Number of languages in the dataset. '
153
- 'n_languages > 1 enables language embeddings')
154
-
155
- cond.add_argument('--load-pitch-from-disk', action='store_true',
156
- help='Use pitch cached on disk with prepare_dataset.py')
157
- cond.add_argument('--pitch-online-method', default='pyin',
158
- choices=['pyin'],
159
- help='Calculate pitch on the fly during training')
160
- cond.add_argument('--pitch-online-dir', type=str, default=None,
161
- help='A directory for storing pitch calculated on-line')
162
- cond.add_argument('--pitch-mean', type=float, default=125.626816, #default=214.72203,
163
- help='Normalization value for pitch')
164
- cond.add_argument('--pitch-std', type=float, default=37.52, #default=65.72038,
165
- help='Normalization value for pitch')
166
- cond.add_argument('--load-mel-from-disk', action='store_true',
167
- help='Use mel-spectrograms cached on disk')
168
-
169
- audio = parser.add_argument_group('audio parameters')
170
- audio.add_argument('--max-wav-value', default=32768.0, type=float,
171
- help='Maximum audio waveform value')
172
- audio.add_argument('--sampling-rate', default=22050, type=int,
173
- help='Sampling rate')
174
- audio.add_argument('--filter-length', default=1024, type=int,
175
- help='Filter length')
176
- audio.add_argument('--hop-length', default=256, type=int,
177
- help='Hop (stride) length')
178
- audio.add_argument('--win-length', default=1024, type=int,
179
- help='Window length')
180
- audio.add_argument('--mel-fmin', default=0.0, type=float,
181
- help='Minimum mel frequency')
182
- audio.add_argument('--mel-fmax', default=8000.0, type=float,
183
- help='Maximum mel frequency')
184
-
185
- dist = parser.add_argument_group('distributed setup')
186
- dist.add_argument('--local_rank', type=int, default=os.getenv('LOCAL_RANK', 0),
187
- help='Rank of the process for multiproc; do not set manually')
188
- dist.add_argument('--world_size', type=int, default=os.getenv('WORLD_SIZE', 1),
189
- help='Number of processes for multiproc; do not set manually')
190
- return parser
191
-
192
-
193
- def reduce_tensor(tensor, num_gpus):
194
- rt = tensor.clone()
195
- dist.all_reduce(rt, op=dist.ReduceOp.SUM)
196
- return rt.true_divide(num_gpus)
197
-
198
-
199
- def init_distributed(args, world_size, rank):
200
- assert torch.cuda.is_available(), "Distributed mode requires CUDA."
201
- print("Initializing distributed training")
202
-
203
- # Set cuda device so everything is done on the right GPU.
204
- torch.cuda.set_device(rank % torch.cuda.device_count())
205
-
206
- # Initialize distributed communication
207
- dist.init_process_group(backend=('nccl' if args.cuda else 'gloo'),
208
- init_method='env://')
209
- print("Done initializing distributed training")
210
-
211
-
212
- def validate(model, epoch, total_iter, criterion, val_loader, distributed_run,
213
- batch_to_gpu, local_rank, ema=False):
214
- was_training = model.training
215
- model.eval()
216
-
217
- tik = time.perf_counter()
218
- with torch.no_grad():
219
- val_meta = defaultdict(float)
220
- val_num_frames = 0
221
- for i, batch in enumerate(val_loader):
222
- x, y, num_frames = batch_to_gpu(batch)
223
- y_pred = model(x)
224
- loss, meta = criterion(y_pred, y, is_training=False, meta_agg='sum')
225
-
226
- if distributed_run:
227
- for k, v in meta.items():
228
- val_meta[k] += reduce_tensor(v, 1)
229
- val_num_frames += reduce_tensor(num_frames.data, 1).item()
230
- else:
231
- for k, v in meta.items():
232
- val_meta[k] += v
233
- val_num_frames += num_frames.item()
234
-
235
- # NOTE: ugly patch to visualize the first utterance of the validation corpus.
236
- # The goal is to determine if the training is progressing properly
237
- if (i == 0) and (local_rank == 0) and (not ema):
238
- # Plot some debug information
239
- fig, axs = plt.subplots(2, 2, figsize=(21,14))
240
-
241
- # - Mel-spectrogram
242
- pred_mel = y_pred[0][0, :, :].cpu().detach().numpy().astype(np.float32).T
243
- orig_mel = y[0][0, :, :].cpu().detach().numpy().astype(np.float32)
244
- axs[0,0].imshow(orig_mel, aspect='auto', origin='lower', interpolation='nearest')
245
- axs[1,0].imshow(pred_mel, aspect='auto', origin='lower', interpolation='nearest')
246
-
247
- # Prosody
248
- f0_pred = y_pred[4][0, :].cpu().detach().numpy().astype(np.float32)
249
- f0_ori = y_pred[5][0, :].cpu().detach().numpy().astype(np.float32)
250
- axs[1,1].plot(f0_ori)
251
- axs[1,1].plot(f0_pred)
252
-
253
- # # Duration
254
- # att_pred = y_pred[2][0, :].cpu().detach().numpy().astype(np.float32)
255
- # att_ori = x[7][0,:].cpu().detach().numpy().astype(np.float32)
256
- # axs[0,1].imshow(att_ori, aspect='auto', origin='lower', interpolation='nearest')
257
-
258
- if not os.path.exists("debug_epoch/"):
259
- os.makedirs("debug_epoch_laila/")
260
-
261
- fig.savefig(f'debug_epoch/{epoch:06d}.png', bbox_inches='tight')
262
-
263
- val_meta = {k: v / len(val_loader.dataset) for k, v in val_meta.items()}
264
-
265
- val_meta['took'] = time.perf_counter() - tik
266
-
267
- log((epoch,) if epoch is not None else (), tb_total_steps=total_iter,
268
- subset='val_ema' if ema else 'val',
269
- data=OrderedDict([
270
- ('loss', val_meta['loss'].item()),
271
- ('mel_loss', val_meta['mel_loss'].item()),
272
- ('frames/s', val_num_frames / val_meta['took']),
273
- ('took', val_meta['took'])]),
274
- )
275
-
276
- if was_training:
277
- model.train()
278
- return val_meta
279
-
280
-
281
- def adjust_learning_rate(total_iter, opt, learning_rate, warmup_iters=None):
282
- if warmup_iters == 0:
283
- scale = 1.0
284
- elif total_iter > warmup_iters:
285
- scale = 1. / (total_iter ** 0.5)
286
- else:
287
- scale = total_iter / (warmup_iters ** 1.5)
288
-
289
- for param_group in opt.param_groups:
290
- param_group['lr'] = learning_rate * scale
291
-
292
-
293
- def apply_ema_decay(model, ema_model, decay):
294
- if not decay:
295
- return
296
- st = model.state_dict()
297
- add_module = hasattr(model, 'module') and not hasattr(ema_model, 'module')
298
- for k, v in ema_model.state_dict().items():
299
- if add_module and not k.startswith('module.'):
300
- k = 'module.' + k
301
- v.copy_(decay * v + (1 - decay) * st[k])
302
-
303
-
304
- def init_multi_tensor_ema(model, ema_model):
305
- model_weights = list(model.state_dict().values())
306
- ema_model_weights = list(ema_model.state_dict().values())
307
- ema_overflow_buf = torch.cuda.IntTensor([0])
308
- return model_weights, ema_model_weights, ema_overflow_buf
309
-
310
-
311
- def apply_multi_tensor_ema(decay, model_weights, ema_weights, overflow_buf):
312
- amp_C.multi_tensor_axpby(
313
- 65536, overflow_buf, [ema_weights, model_weights, ema_weights],
314
- decay, 1-decay, -1)
315
-
316
-
317
- def main():
318
- parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
319
- allow_abbrev=False)
320
- parser = parse_args(parser)
321
- args, _ = parser.parse_known_args()
322
-
323
- if args.p_arpabet > 0.0:
324
- cmudict.initialize(args.cmudict_path, args.heteronyms_path)
325
-
326
- distributed_run = args.world_size > 1
327
-
328
- torch.manual_seed(args.seed + args.local_rank)
329
- np.random.seed(args.seed + args.local_rank)
330
-
331
- if args.local_rank == 0:
332
- if not os.path.exists(args.output):
333
- os.makedirs(args.output)
334
-
335
- log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
336
- tb_subsets = ['train', 'val']
337
- if args.ema_decay > 0.0:
338
- tb_subsets.append('val_ema')
339
-
340
- logger.init(log_fpath, args.output, enabled=(args.local_rank == 0),
341
- tb_subsets=tb_subsets)
342
- logger.parameters(vars(args), tb_subset='train')
343
-
344
- parser = models.parse_model_args('FastPitch', parser)
345
- args, unk_args = parser.parse_known_args()
346
- if len(unk_args) > 0:
347
- raise ValueError(f'Invalid options {unk_args}')
348
-
349
- torch.backends.cudnn.benchmark = args.cudnn_benchmark
350
-
351
- if distributed_run:
352
- init_distributed(args, args.world_size, args.local_rank)
353
- else:
354
- if args.trainloader_repeats > 1:
355
- print('WARNING: Disabled --trainloader-repeats, supported only for'
356
- ' multi-GPU data loading.')
357
- args.trainloader_repeats = 1
358
-
359
- device = torch.device('cuda' if args.cuda else 'cpu')
360
- model_config = models.get_model_config('FastPitch', args)
361
- model = models.get_model('FastPitch', model_config, device)
362
-
363
- attention_kl_loss = AttentionBinarizationLoss()
364
-
365
- # Store pitch mean/std as params to translate from Hz during inference
366
- model.pitch_mean[0] = args.pitch_mean
367
- model.pitch_std[0] = args.pitch_std
368
-
369
- kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
370
- weight_decay=args.weight_decay)
371
- if args.optimizer == 'adam':
372
- optimizer = FusedAdam(model.parameters(), **kw)
373
- # optimizer = torch.optim.Adam(model.parameters(), **kw)
374
- elif args.optimizer == 'lamb':
375
-
376
- optimizer = FusedLAMB(model.parameters(), **kw)
377
- # optimizer = torch.optim.Adam(model.parameters(), **kw)
378
- else:
379
- raise ValueError
380
-
381
- scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
382
-
383
- if args.ema_decay > 0:
384
- ema_model = copy.deepcopy(model)
385
- else:
386
- ema_model = None
387
-
388
- if distributed_run:
389
- model = DistributedDataParallel(
390
- model, device_ids=[args.local_rank], output_device=args.local_rank,
391
- find_unused_parameters=True)
392
-
393
- train_state = {'epoch': 1, 'total_iter': 1}
394
- checkpointer = Checkpointer(args.output, args.keep_milestones)
395
-
396
- checkpointer.maybe_load(model, optimizer, scaler, train_state, args,
397
- ema_model)
398
-
399
- start_epoch = train_state['epoch']
400
- total_iter = train_state['total_iter']
401
-
402
- criterion = FastPitchLoss(
403
- dur_predictor_loss_scale=args.dur_predictor_loss_scale,
404
- pitch_predictor_loss_scale=args.pitch_predictor_loss_scale,
405
- attn_loss_scale=args.attn_loss_scale)
406
-
407
- collate_fn = TTSCollate()
408
-
409
- if args.local_rank == 0:
410
- prepare_tmp(args.pitch_online_dir)
411
-
412
- trainset = TTSDataset(audiopaths_and_text=args.training_files, **vars(args))
413
- valset = TTSDataset(audiopaths_and_text=args.validation_files, **vars(args))
414
-
415
- if distributed_run:
416
- train_sampler = RepeatedDistributedSampler(args.trainloader_repeats,
417
- trainset, drop_last=True)
418
- val_sampler = DistributedSampler(valset)
419
- shuffle = False
420
- else:
421
- train_sampler, val_sampler, shuffle = None, None, False  # was True
422
-
423
- # 4 workers are optimal on DGX-1 (from epoch 2 onwards)
424
- kw = {'num_workers': args.num_workers, 'batch_size': args.batch_size,
425
- 'collate_fn': collate_fn}
426
- train_loader = RepeatedDataLoader(args.trainloader_repeats, trainset,
427
- shuffle=shuffle, drop_last=True,
428
- sampler=train_sampler, pin_memory=True,
429
- persistent_workers=True, **kw)
430
- val_loader = DataLoader(valset, shuffle=False, sampler=val_sampler,
431
- pin_memory=False, **kw)
432
- if args.ema_decay:
433
- mt_ema_params = init_multi_tensor_ema(model, ema_model)
434
-
435
- model.train()
436
- bmark_stats = BenchmarkStats()
437
-
438
- torch.cuda.synchronize()
439
- for epoch in range(start_epoch, args.epochs + 1):
440
- epoch_start_time = time.perf_counter()
441
-
442
- epoch_loss = 0.0
443
- epoch_mel_loss = 0.0
444
- epoch_num_frames = 0
445
- epoch_frames_per_sec = 0.0
446
-
447
- if distributed_run:
448
- train_loader.sampler.set_epoch(epoch)
449
-
450
- iter_loss = 0
451
- iter_num_frames = 0
452
- iter_meta = {}
453
- iter_start_time = time.perf_counter()
454
-
455
- epoch_iter = 1
456
- for batch, accum_step in zip(train_loader,
457
- cycle(range(1, args.grad_accumulation + 1))):
458
- if accum_step == 1:
459
- adjust_learning_rate(total_iter, optimizer, args.learning_rate,
460
- args.warmup_steps)
461
-
462
- model.zero_grad(set_to_none=True)
463
-
464
- x, y, num_frames = batch_to_gpu(batch)
465
-
466
- with torch.cuda.amp.autocast(enabled=args.amp):
467
- y_pred = model(x)
468
- loss, meta = criterion(y_pred, y)
469
-
470
- if (args.kl_loss_start_epoch is not None
471
- and epoch >= args.kl_loss_start_epoch):
472
-
473
- if args.kl_loss_start_epoch == epoch and epoch_iter == 1:
474
- print('Begin hard_attn loss')
475
-
476
- _, _, _, _, _, _, _, _, attn_soft, attn_hard, _, _ = y_pred
477
- binarization_loss = attention_kl_loss(attn_hard, attn_soft)
478
- kl_weight = min((epoch - args.kl_loss_start_epoch) / args.kl_loss_warmup_epochs, 1.0) * args.kl_loss_weight
479
- meta['kl_loss'] = binarization_loss.clone().detach() * kl_weight
480
- loss += kl_weight * binarization_loss
481
-
482
- else:
483
- meta['kl_loss'] = torch.zeros_like(loss)
484
- kl_weight = 0
485
- binarization_loss = 0
486
-
487
- loss /= args.grad_accumulation
488
- meta = {k: v / args.grad_accumulation
489
- for k, v in meta.items()}
490
-
491
- if args.amp:
492
- scaler.scale(loss).backward()
493
- else:
494
- loss.backward()
495
-
496
- if distributed_run:
497
- reduced_loss = reduce_tensor(loss.data, args.world_size).item()
498
- reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
499
- meta = {k: reduce_tensor(v, args.world_size) for k, v in meta.items()}
500
- else:
501
- reduced_loss = loss.item()
502
- reduced_num_frames = num_frames.item()
503
- if np.isnan(reduced_loss):
504
- raise Exception("loss is NaN")
505
-
506
- iter_loss += reduced_loss
507
- iter_num_frames += reduced_num_frames
508
- iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}
509
- if accum_step % args.grad_accumulation == 0:
510
-
511
- logger.log_grads_tb(total_iter, model)
512
- if args.amp:
513
- scaler.unscale_(optimizer)
514
- torch.nn.utils.clip_grad_norm_(
515
- model.parameters(), args.grad_clip_thresh)
516
- scaler.step(optimizer)
517
- scaler.update()
518
- else:
519
- torch.nn.utils.clip_grad_norm_(
520
- model.parameters(), args.grad_clip_thresh)
521
- optimizer.step()
522
-
523
- if args.ema_decay > 0.0:
524
- apply_multi_tensor_ema(args.ema_decay, *mt_ema_params)
525
-
526
- iter_mel_loss = iter_meta['mel_loss'].item()
527
- iter_kl_loss = iter_meta['kl_loss'].item()
528
- iter_time = time.perf_counter() - iter_start_time
529
- epoch_frames_per_sec += iter_num_frames / iter_time
530
- epoch_loss += iter_loss
531
- epoch_num_frames += iter_num_frames
532
- epoch_mel_loss += iter_mel_loss
533
-
534
- num_iters = len(train_loader) // args.grad_accumulation
535
- log((epoch, epoch_iter, num_iters), tb_total_steps=total_iter,
536
- subset='train', data=OrderedDict([
537
- ('loss', iter_loss),
538
- ('mel_loss', iter_mel_loss),
539
- ('kl_loss', iter_kl_loss),
540
- ('kl_weight', kl_weight),
541
- ('frames/s', iter_num_frames / iter_time),
542
- ('took', iter_time),
543
- ('lrate', optimizer.param_groups[0]['lr'])]),
544
- )
545
-
546
- iter_loss = 0
547
- iter_num_frames = 0
548
- iter_meta = {}
549
- iter_start_time = time.perf_counter()
550
-
551
- if epoch_iter == num_iters:
552
- break
553
- epoch_iter += 1
554
- total_iter += 1
555
-
556
- # Finished epoch
557
- epoch_loss /= epoch_iter
558
- epoch_mel_loss /= epoch_iter
559
- epoch_time = time.perf_counter() - epoch_start_time
560
- log((epoch,), tb_total_steps=None, subset='train_avg',
561
- data=OrderedDict([
562
- ('loss', epoch_loss),
563
- ('mel_loss', epoch_mel_loss),
564
- ('frames/s', epoch_num_frames / epoch_time),
565
- ('took', epoch_time)]),
566
- )
567
- bmark_stats.update(epoch_num_frames, epoch_loss, epoch_mel_loss,
568
- epoch_time)
569
-
570
- if epoch % args.validation_freq == 0:
571
- validate(model, epoch, total_iter, criterion, val_loader,
572
- distributed_run, batch_to_gpu, ema=False, local_rank=args.local_rank)
573
-
574
- if args.ema_decay > 0:
575
- validate(ema_model, epoch, total_iter, criterion, val_loader,
576
- distributed_run, batch_to_gpu, args.local_rank, ema=True)
577
-
578
- # save before making sched.step() for proper loading of LR
579
- checkpointer.maybe_save(args, model, ema_model, optimizer, scaler,
580
- epoch, total_iter, model_config)
581
- logger.flush()
582
-
583
- # Finished training
584
- if len(bmark_stats) > 0:
585
- log((), tb_total_steps=None, subset='train_avg',
586
- data=bmark_stats.get(args.benchmark_epochs_num))
587
-
588
- validate(model, None, total_iter, criterion, val_loader, distributed_run,
589
- batch_to_gpu)
590
-
591
-
592
- if __name__ == '__main__':
593
- main()
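As a closing note, adjust_learning_rate above implements a Noam-style schedule: the base rate (-lr, 0.1 in the training scripts) is scaled by total_iter / warmup_iters**1.5 during warmup and by total_iter**-0.5 afterwards, so the two branches meet exactly when warmup ends. A minimal standalone sketch of that scale factor, assuming the scripts' WARMUP_STEPS=1000 default:

def lr_scale(total_iter, warmup_iters=1000):
    # Mirrors adjust_learning_rate: ramp up during warmup, then inverse-sqrt decay.
    if warmup_iters == 0:
        return 1.0
    if total_iter > warmup_iters:
        return total_iter ** -0.5
    return total_iter / warmup_iters ** 1.5

# Both branches give 1000 ** -0.5 ≈ 0.0316 at iteration 1000,
# i.e. a peak learning rate of roughly 0.1 * 0.0316 ≈ 3.2e-3 before the decay.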