Futuresony committed
Commit 5d490e8 · verified · 1 parent: 24e6612

Update tts.py

Files changed (1): tts.py (+0 -173)
tts.py CHANGED
@@ -1,176 +1,3 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import os
-import re
-import tempfile
-import torch
-import sys
-import gradio as gr
-import numpy as np
-
-from huggingface_hub import hf_hub_download
-
-# Setup TTS env
-if "vits" not in sys.path:
-    sys.path.append("vits")
-
-from vits import commons, utils
-from vits.models import SynthesizerTrn
-
-
-TTS_LANGUAGES = {}
-with open(f"data/tts/all_langs.tsv") as f:
-    for line in f:
-        iso, name = line.split(" ", 1)
-        TTS_LANGUAGES[iso.strip()] = name.strip()
-
-
-class TextMapper(object):
-    def __init__(self, vocab_file):
-        self.symbols = [
-            x.replace("\n", "") for x in open(vocab_file, encoding="utf-8").readlines()
-        ]
-        self.SPACE_ID = self.symbols.index(" ")
-        self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
-        self._id_to_symbol = {i: s for i, s in enumerate(self.symbols)}
-
-    def text_to_sequence(self, text, cleaner_names):
-        """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
-        Args:
-            text: string to convert to a sequence
-            cleaner_names: names of the cleaner functions to run the text through
-        Returns:
-            List of integers corresponding to the symbols in the text
-        """
-        sequence = []
-        clean_text = text.strip()
-        for symbol in clean_text:
-            symbol_id = self._symbol_to_id[symbol]
-            sequence += [symbol_id]
-        return sequence
-
-    def uromanize(self, text, uroman_pl):
-        iso = "xxx"
-        with tempfile.NamedTemporaryFile() as tf, tempfile.NamedTemporaryFile() as tf2:
-            with open(tf.name, "w") as f:
-                f.write("\n".join([text]))
-            cmd = f"perl " + uroman_pl
-            cmd += f" -l {iso} "
-            cmd += f" < {tf.name} > {tf2.name}"
-            os.system(cmd)
-            outtexts = []
-            with open(tf2.name) as f:
-                for line in f:
-                    line = re.sub(r"\s+", " ", line).strip()
-                    outtexts.append(line)
-            outtext = outtexts[0]
-        return outtext
-
-    def get_text(self, text, hps):
-        text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
-        if hps.data.add_blank:
-            text_norm = commons.intersperse(text_norm, 0)
-        text_norm = torch.LongTensor(text_norm)
-        return text_norm
-
-    def filter_oov(self, text, lang=None):
-        text = self.preprocess_char(text, lang=lang)
-        val_chars = self._symbol_to_id
-        txt_filt = "".join(list(filter(lambda x: x in val_chars, text)))
-        return txt_filt
-
-    def preprocess_char(self, text, lang=None):
-        """
-        Special treatement of characters in certain languages
-        """
-        if lang == "ron":
-            text = text.replace("ț", "ţ")
-            print(f"{lang} (ț -> ţ): {text}")
-        return text
-
-
-def synthesize(text=None, lang=None, speed=None):
-    if speed is None:
-        speed = 1.0
-
-    lang_code = lang.split()[0].strip()
-
-    vocab_file = hf_hub_download(
-        repo_id="facebook/mms-tts",
-        filename="vocab.txt",
-        subfolder=f"models/{lang_code}",
-    )
-    config_file = hf_hub_download(
-        repo_id="facebook/mms-tts",
-        filename="config.json",
-        subfolder=f"models/{lang_code}",
-    )
-    g_pth = hf_hub_download(
-        repo_id="facebook/mms-tts",
-        filename="G_100000.pth",
-        subfolder=f"models/{lang_code}",
-    )
-
-    if torch.cuda.is_available():
-        device = torch.device("cuda")
-    elif (
-        hasattr(torch.backends, "mps")
-        and torch.backends.mps.is_available()
-        and torch.backends.mps.is_built()
-    ):
-        device = torch.device("mps")
-    else:
-        device = torch.device("cpu")
-
-    print(f"Run inference with {device}")
-
-    assert os.path.isfile(config_file), f"{config_file} doesn't exist"
-    hps = utils.get_hparams_from_file(config_file)
-    text_mapper = TextMapper(vocab_file)
-    net_g = SynthesizerTrn(
-        len(text_mapper.symbols),
-        hps.data.filter_length // 2 + 1,
-        hps.train.segment_size // hps.data.hop_length,
-        **hps.model,
-    )
-    net_g.to(device)
-    _ = net_g.eval()
-
-    _ = utils.load_checkpoint(g_pth, net_g, None)
-
-    is_uroman = hps.data.training_files.split(".")[-1] == "uroman"
-
-    if is_uroman:
-        uroman_dir = "uroman"
-        assert os.path.exists(uroman_dir)
-        uroman_pl = os.path.join(uroman_dir, "bin", "uroman.pl")
-        text = text_mapper.uromanize(text, uroman_pl)
-
-    text = text.lower()
-    text = text_mapper.filter_oov(text, lang=lang)
-    stn_tst = text_mapper.get_text(text, hps)
-    with torch.no_grad():
-        x_tst = stn_tst.unsqueeze(0).to(device)
-        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
-        hyp = (
-            net_g.infer(
-                x_tst,
-                x_tst_lengths,
-                noise_scale=0.667,
-                noise_scale_w=0.8,
-                length_scale=1.0 / speed,
-            )[0][0, 0]
-            .cpu()
-            .float()
-            .numpy()
-        )
-
-    hyp = (hyp * 32768).astype(np.int16)
-    return (hps.data.sampling_rate, hyp), text
-
 
 TTS_EXAMPLES = [
     ["I am going to the store.", "eng (English)", 1.0],