Update GPT_SoVITS/app_colab.py

GPT_SoVITS/app_colab.py (changed, +9 −5)
```diff
@@ -140,7 +140,7 @@ else:
 
 def change_sovits_weights(sovits_path):
     global vq_model, hps
-    dict_s2 = torch.load(sovits_path
+    dict_s2 = torch.load(sovits_path)
     hps = dict_s2["config"]
     hps = DictToAttrRecursive(hps)
     hps.model.semantic_frame_rate = "25hz"
@@ -168,7 +168,7 @@ change_sovits_weights(sovits_path)
 def change_gpt_weights(gpt_path):
     global hz, max_sec, t2s_model, config
     hz = 50
-    dict_s1 = torch.load(gpt_path
+    dict_s1 = torch.load(gpt_path)
    config = dict_s1["config"]
     max_sec = config["data"]["max_sec"]
     t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
```
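The first two hunks close an unbalanced parenthesis in each `torch.load` call; as committed before this change, the file would fail with a `SyntaxError` as soon as Python imported it. Tangential to this commit, but relevant when running the same code off Colab: a minimal, hypothetical loading helper (the name `load_state` and the device fallback are assumptions, not code from app_colab.py):

```python
import torch

def load_state(ckpt_path: str) -> dict:
    # Assumption for illustration: choose the device at load time so a
    # checkpoint saved on GPU still opens on a CPU-only machine.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.load(ckpt_path, map_location=device)
```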
```diff
@@ -426,16 +426,20 @@ def vc_main(wav_path, text, language, prompt_wav, noise_scale=0.5):
     phones, word2ph, norm_text = get_cleaned_text_final(text, language)
 
     spec = get_spepc(hps, prompt_wav)
-    codes = get_code_from_wav(wav_path)[None, None]
+    spec = spec.to(device)
+    codes = get_code_from_wav(wav_path)[None, None].to(device)  # must be 3D: [n_q, B, T]
     ge = vq_model.ref_enc(spec)  # [B, D, T/1]
     quantized = vq_model.quantizer.decode(codes)  # [B, D, T]
     if hps.model.semantic_frame_rate == "25hz":
         quantized = F.interpolate(
             quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
         )
+    lengths_tensor = torch.LongTensor([quantized.shape[-1]]).to(device)
+    phones_tensor = torch.LongTensor(phones)[None].to(device)
+    phones_lengths_tensor = torch.LongTensor([len(phones)]).to(device)
+
     _, m_p, logs_p, y_mask = vq_model.enc_p(
-        quantized,
-        torch.LongTensor(phones)[None], torch.LongTensor([len(phones)]), ge
+        quantized, lengths_tensor, phones_tensor, phones_lengths_tensor, ge
     )
     z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
     z = vq_model.flow(z_p, y_mask, g=ge, reverse=True)
```
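The third hunk is the substantive fix: `spec` and `codes` are moved onto the model's device before any forward pass, and `vq_model.enc_p` now receives an explicit frame-length tensor (`lengths_tensor`) alongside the phoneme tensor and its length, instead of only the phoneme arguments. The untouched 25 Hz branch doubles the time axis of the quantized features with nearest-neighbour interpolation; a self-contained toy of that step (all shapes invented for illustration):

```python
import torch
import torch.nn.functional as F

# Toy [B, D, T] tensor: one second of features at a 25 Hz frame rate.
quantized = torch.randn(1, 4, 25)

# The same call the diff keeps: nearest-neighbour upsampling to 50 Hz.
upsampled = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
print(upsampled.shape)  # torch.Size([1, 4, 50])
```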
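Why the added `.to(device)` calls matter: in the Colab setup the model weights live on the GPU, and feeding CPU tensors into CUDA modules raises a device-mismatch `RuntimeError` at call time, the kind of failure that surfaces as a crashed app rather than a visible traceback. A generic sketch of the input-building pattern the diff adopts (the `phones` values here are made-up phoneme IDs, not from the app):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
phones = [12, 7, 33, 5]  # hypothetical phoneme IDs, for illustration only

# Mirror the diff: build every model input on the same device as the weights.
phones_tensor = torch.LongTensor(phones)[None].to(device)           # [1, T_text]
phones_lengths_tensor = torch.LongTensor([len(phones)]).to(device)  # [1]
```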