SohomToom committed
Commit 2b7ac8b · verified · 1 Parent(s): 40aad53

Update openvoice/api.py

Files changed (1):
  1. openvoice/api.py +202 -202
openvoice/api.py CHANGED
@@ -1,202 +1,202 @@
-import torch
-import numpy as np
-import re
-import soundfile
-from openvoice import utils
-from openvoice import commons
-import os
-import librosa
-from openvoice.text import text_to_sequence
-from openvoice.mel_processing import spectrogram_torch
-from openvoice.models import SynthesizerTrn
-
-
-class OpenVoiceBaseClass(object):
-    def __init__(self,
-                 config_path,
-                 device='cuda:0'):
-        if 'cuda' in device:
-            assert torch.cuda.is_available()
-
-        hps = utils.get_hparams_from_file(config_path)
-
-        model = SynthesizerTrn(
-            len(getattr(hps, 'symbols', [])),
-            hps.data.filter_length // 2 + 1,
-            n_speakers=hps.data.n_speakers,
-            **hps.model,
-        ).to(device)
-
-        model.eval()
-        self.model = model
-        self.hps = hps
-        self.device = device
-
-    def load_ckpt(self, ckpt_path):
-        checkpoint_dict = torch.load(ckpt_path, map_location=torch.device(self.device))
-        a, b = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
-        print("Loaded checkpoint '{}'".format(ckpt_path))
-        print('missing/unexpected keys:', a, b)
-
-
-class BaseSpeakerTTS(OpenVoiceBaseClass):
-    language_marks = {
-        "english": "EN",
-        "chinese": "ZH",
-    }
-
-    @staticmethod
-    def get_text(text, hps, is_symbol):
-        text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
-        if hps.data.add_blank:
-            text_norm = commons.intersperse(text_norm, 0)
-        text_norm = torch.LongTensor(text_norm)
-        return text_norm
-
-    @staticmethod
-    def audio_numpy_concat(segment_data_list, sr, speed=1.):
-        audio_segments = []
-        for segment_data in segment_data_list:
-            audio_segments += segment_data.reshape(-1).tolist()
-            audio_segments += [0] * int((sr * 0.05)/speed)
-        audio_segments = np.array(audio_segments).astype(np.float32)
-        return audio_segments
-
-    @staticmethod
-    def split_sentences_into_pieces(text, language_str):
-        texts = utils.split_sentence(text, language_str=language_str)
-        print(" > Text splitted to sentences.")
-        print('\n'.join(texts))
-        print(" > ===========================")
-        return texts
-
-    def tts(self, text, output_path, speaker, language='English', speed=1.0):
-        mark = self.language_marks.get(language.lower(), None)
-        assert mark is not None, f"language {language} is not supported"
-
-        texts = self.split_sentences_into_pieces(text, mark)
-
-        audio_list = []
-        for t in texts:
-            t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
-            t = f'[{mark}]{t}[{mark}]'
-            stn_tst = self.get_text(t, self.hps, False)
-            device = self.device
-            speaker_id = self.hps.speakers[speaker]
-            with torch.no_grad():
-                x_tst = stn_tst.unsqueeze(0).to(device)
-                x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
-                sid = torch.LongTensor([speaker_id]).to(device)
-                audio = self.model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.6,
-                                         length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
-            audio_list.append(audio)
-        audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
-
-        if output_path is None:
-            return audio
-        else:
-            soundfile.write(output_path, audio, self.hps.data.sampling_rate)
-
-
-class ToneColorConverter(OpenVoiceBaseClass):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        if kwargs.get('enable_watermark', True):
-            import wavmark
-            self.watermark_model = wavmark.load_model().to(self.device)
-        else:
-            self.watermark_model = None
-        self.version = getattr(self.hps, '_version_', "v1")
-
-
-
-    def extract_se(self, ref_wav_list, se_save_path=None):
-        if isinstance(ref_wav_list, str):
-            ref_wav_list = [ref_wav_list]
-
-        device = self.device
-        hps = self.hps
-        gs = []
-
-        for fname in ref_wav_list:
-            audio_ref, sr = librosa.load(fname, sr=hps.data.sampling_rate)
-            y = torch.FloatTensor(audio_ref)
-            y = y.to(device)
-            y = y.unsqueeze(0)
-            y = spectrogram_torch(y, hps.data.filter_length,
-                                  hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
-                                  center=False).to(device)
-            with torch.no_grad():
-                g = self.model.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
-                gs.append(g.detach())
-        gs = torch.stack(gs).mean(0)
-
-        if se_save_path is not None:
-            os.makedirs(os.path.dirname(se_save_path), exist_ok=True)
-            torch.save(gs.cpu(), se_save_path)
-
-        return gs
-
-    def convert(self, audio_src_path, src_se, tgt_se, output_path=None, tau=0.3, message="default"):
-        hps = self.hps
-        # load audio
-        audio, sample_rate = librosa.load(audio_src_path, sr=hps.data.sampling_rate)
-        audio = torch.tensor(audio).float()
-
-        with torch.no_grad():
-            y = torch.FloatTensor(audio).to(self.device)
-            y = y.unsqueeze(0)
-            spec = spectrogram_torch(y, hps.data.filter_length,
-                                     hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
-                                     center=False).to(self.device)
-            spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.device)
-            audio = self.model.voice_conversion(spec, spec_lengths, sid_src=src_se, sid_tgt=tgt_se, tau=tau)[0][
-                0, 0].data.cpu().float().numpy()
-            audio = self.add_watermark(audio, message)
-            if output_path is None:
-                return audio
-            else:
-                soundfile.write(output_path, audio, hps.data.sampling_rate)
-
-    def add_watermark(self, audio, message):
-        if self.watermark_model is None:
-            return audio
-        device = self.device
-        bits = utils.string_to_bits(message).reshape(-1)
-        n_repeat = len(bits) // 32
-
-        K = 16000
-        coeff = 2
-        for n in range(n_repeat):
-            trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
-            if len(trunck) != K:
-                print('Audio too short, fail to add watermark')
-                break
-            message_npy = bits[n * 32: (n + 1) * 32]
-
-            with torch.no_grad():
-                signal = torch.FloatTensor(trunck).to(device)[None]
-                message_tensor = torch.FloatTensor(message_npy).to(device)[None]
-                signal_wmd_tensor = self.watermark_model.encode(signal, message_tensor)
-                signal_wmd_npy = signal_wmd_tensor.detach().cpu().squeeze()
-            audio[(coeff * n) * K: (coeff * n + 1) * K] = signal_wmd_npy
-        return audio
-
-    def detect_watermark(self, audio, n_repeat):
-        bits = []
-        K = 16000
-        coeff = 2
-        for n in range(n_repeat):
-            trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
-            if len(trunck) != K:
-                print('Audio too short, fail to detect watermark')
-                return 'Fail'
-            with torch.no_grad():
-                signal = torch.FloatTensor(trunck).to(self.device).unsqueeze(0)
-                message_decoded_npy = (self.watermark_model.decode(signal) >= 0.5).int().detach().cpu().numpy().squeeze()
-            bits.append(message_decoded_npy)
-        bits = np.stack(bits).reshape(-1, 8)
-        message = utils.bits_to_string(bits)
-        return message
-
+import torch
+import numpy as np
+import re
+import soundfile
+from openvoice import utils
+from openvoice import commons
+import os
+import librosa
+from openvoice.text import text_to_sequence
+from openvoice.mel_processing import spectrogram_torch
+from openvoice.models import SynthesizerTrn
+
+
+class OpenVoiceBaseClass(object):
+    def __init__(self,
+                 config_path,
+                 device='cuda:0'):
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(f"Using device: {device}")
+
+        hps = utils.get_hparams_from_file(config_path)
+
+        model = SynthesizerTrn(
+            len(getattr(hps, 'symbols', [])),
+            hps.data.filter_length // 2 + 1,
+            n_speakers=hps.data.n_speakers,
+            **hps.model,
+        ).to(device)
+
+        model.eval()
+        self.model = model
+        self.hps = hps
+        self.device = device
+
+    def load_ckpt(self, ckpt_path):
+        checkpoint_dict = torch.load(ckpt_path, map_location=torch.device(self.device))
+        a, b = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
+        print("Loaded checkpoint '{}'".format(ckpt_path))
+        print('missing/unexpected keys:', a, b)
+
+
+class BaseSpeakerTTS(OpenVoiceBaseClass):
+    language_marks = {
+        "english": "EN",
+        "chinese": "ZH",
+    }
+
+    @staticmethod
+    def get_text(text, hps, is_symbol):
+        text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
+        if hps.data.add_blank:
+            text_norm = commons.intersperse(text_norm, 0)
+        text_norm = torch.LongTensor(text_norm)
+        return text_norm
+
+    @staticmethod
+    def audio_numpy_concat(segment_data_list, sr, speed=1.):
+        audio_segments = []
+        for segment_data in segment_data_list:
+            audio_segments += segment_data.reshape(-1).tolist()
+            audio_segments += [0] * int((sr * 0.05)/speed)
+        audio_segments = np.array(audio_segments).astype(np.float32)
+        return audio_segments
+
+    @staticmethod
+    def split_sentences_into_pieces(text, language_str):
+        texts = utils.split_sentence(text, language_str=language_str)
+        print(" > Text splitted to sentences.")
+        print('\n'.join(texts))
+        print(" > ===========================")
+        return texts
+
+    def tts(self, text, output_path, speaker, language='English', speed=1.0):
+        mark = self.language_marks.get(language.lower(), None)
+        assert mark is not None, f"language {language} is not supported"
+
+        texts = self.split_sentences_into_pieces(text, mark)
+
+        audio_list = []
+        for t in texts:
+            t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
+            t = f'[{mark}]{t}[{mark}]'
+            stn_tst = self.get_text(t, self.hps, False)
+            device = self.device
+            speaker_id = self.hps.speakers[speaker]
+            with torch.no_grad():
+                x_tst = stn_tst.unsqueeze(0).to(device)
+                x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
+                sid = torch.LongTensor([speaker_id]).to(device)
+                audio = self.model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.6,
+                                         length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
+            audio_list.append(audio)
+        audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
+
+        if output_path is None:
+            return audio
+        else:
+            soundfile.write(output_path, audio, self.hps.data.sampling_rate)
+
+
+class ToneColorConverter(OpenVoiceBaseClass):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if kwargs.get('enable_watermark', True):
+            import wavmark
+            self.watermark_model = wavmark.load_model().to(self.device)
+        else:
+            self.watermark_model = None
+        self.version = getattr(self.hps, '_version_', "v1")
+
+
+
+    def extract_se(self, ref_wav_list, se_save_path=None):
+        if isinstance(ref_wav_list, str):
+            ref_wav_list = [ref_wav_list]
+
+        device = self.device
+        hps = self.hps
+        gs = []
+
+        for fname in ref_wav_list:
+            audio_ref, sr = librosa.load(fname, sr=hps.data.sampling_rate)
+            y = torch.FloatTensor(audio_ref)
+            y = y.to(device)
+            y = y.unsqueeze(0)
+            y = spectrogram_torch(y, hps.data.filter_length,
+                                  hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
+                                  center=False).to(device)
+            with torch.no_grad():
+                g = self.model.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
+                gs.append(g.detach())
+        gs = torch.stack(gs).mean(0)
+
+        if se_save_path is not None:
+            os.makedirs(os.path.dirname(se_save_path), exist_ok=True)
+            torch.save(gs.cpu(), se_save_path)
+
+        return gs
+
+    def convert(self, audio_src_path, src_se, tgt_se, output_path=None, tau=0.3, message="default"):
+        hps = self.hps
+        # load audio
+        audio, sample_rate = librosa.load(audio_src_path, sr=hps.data.sampling_rate)
+        audio = torch.tensor(audio).float()
+
+        with torch.no_grad():
+            y = torch.FloatTensor(audio).to(self.device)
+            y = y.unsqueeze(0)
+            spec = spectrogram_torch(y, hps.data.filter_length,
+                                     hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
+                                     center=False).to(self.device)
+            spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.device)
+            audio = self.model.voice_conversion(spec, spec_lengths, sid_src=src_se, sid_tgt=tgt_se, tau=tau)[0][
+                0, 0].data.cpu().float().numpy()
+            audio = self.add_watermark(audio, message)
+            if output_path is None:
+                return audio
+            else:
+                soundfile.write(output_path, audio, hps.data.sampling_rate)
+
+    def add_watermark(self, audio, message):
+        if self.watermark_model is None:
+            return audio
+        device = self.device
+        bits = utils.string_to_bits(message).reshape(-1)
+        n_repeat = len(bits) // 32
+
+        K = 16000
+        coeff = 2
+        for n in range(n_repeat):
+            trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
+            if len(trunck) != K:
+                print('Audio too short, fail to add watermark')
+                break
+            message_npy = bits[n * 32: (n + 1) * 32]
+
+            with torch.no_grad():
+                signal = torch.FloatTensor(trunck).to(device)[None]
+                message_tensor = torch.FloatTensor(message_npy).to(device)[None]
+                signal_wmd_tensor = self.watermark_model.encode(signal, message_tensor)
+                signal_wmd_npy = signal_wmd_tensor.detach().cpu().squeeze()
+            audio[(coeff * n) * K: (coeff * n + 1) * K] = signal_wmd_npy
+        return audio
+
+    def detect_watermark(self, audio, n_repeat):
+        bits = []
+        K = 16000
+        coeff = 2
+        for n in range(n_repeat):
+            trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
+            if len(trunck) != K:
+                print('Audio too short, fail to detect watermark')
+                return 'Fail'
+            with torch.no_grad():
+                signal = torch.FloatTensor(trunck).to(self.device).unsqueeze(0)
+                message_decoded_npy = (self.watermark_model.decode(signal) >= 0.5).int().detach().cpu().numpy().squeeze()
+            bits.append(message_decoded_npy)
+        bits = np.stack(bits).reshape(-1, 8)
+        message = utils.bits_to_string(bits)
+        return message
+
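
For orientation, a minimal usage sketch of the two classes defined in this file. The checkpoint/config paths, speaker key ('default'), reference clip, and watermark message below are illustrative assumptions, not values taken from this commit.

import torch
from openvoice.api import BaseSpeakerTTS, ToneColorConverter

# Illustrative paths -- point these at wherever the OpenVoice checkpoints live locally.
ckpt_base = 'checkpoints/base_speakers/EN'
ckpt_converter = 'checkpoints/converter'
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Synthesize speech in the base speaker's voice.
tts_model = BaseSpeakerTTS(f'{ckpt_base}/config.json', device=device)
tts_model.load_ckpt(f'{ckpt_base}/checkpoint.pth')
tts_model.tts('This is a test sentence.', 'tmp.wav', speaker='default', language='English')

# Re-render the synthesized audio with the timbre of a reference speaker.
converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
source_se = torch.load(f'{ckpt_base}/en_default_se.pth', map_location=device)
target_se = converter.extract_se('reference_speaker.wav')  # tone-color embedding from a reference clip
converter.convert(audio_src_path='tmp.wav', src_se=source_se,
                  tgt_se=target_se, output_path='output.wav', message='@MyShell')

Note that extract_se averages the reference-encoder output over the supplied clips, so a list of file paths can be passed instead of a single path when more reference audio is available.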