kevinwang676 committed on
Commit
f1e44ab
·
verified ·
1 Parent(s): 072e72d

Update GPT_SoVITS/app_colab.py

Browse files
Files changed (1) hide show
  1. GPT_SoVITS/app_colab.py +9 -5
GPT_SoVITS/app_colab.py CHANGED
@@ -140,7 +140,7 @@ else:
140
 
141
  def change_sovits_weights(sovits_path):
142
  global vq_model, hps
143
- dict_s2 = torch.load(sovits_path, map_location="cpu")
144
  hps = dict_s2["config"]
145
  hps = DictToAttrRecursive(hps)
146
  hps.model.semantic_frame_rate = "25hz"
@@ -168,7 +168,7 @@ change_sovits_weights(sovits_path)
168
  def change_gpt_weights(gpt_path):
169
  global hz, max_sec, t2s_model, config
170
  hz = 50
171
- dict_s1 = torch.load(gpt_path, map_location="cpu")
172
  config = dict_s1["config"]
173
  max_sec = config["data"]["max_sec"]
174
  t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
@@ -426,16 +426,20 @@ def vc_main(wav_path, text, language, prompt_wav, noise_scale=0.5):
426
  phones, word2ph, norm_text = get_cleaned_text_final(text, language)
427
 
428
  spec = get_spepc(hps, prompt_wav)
429
- codes = get_code_from_wav(wav_path)[None, None] # 必须是 3D, [n_q, B, T]
 
430
  ge = vq_model.ref_enc(spec) # [B, D, T/1]
431
  quantized = vq_model.quantizer.decode(codes) # [B, D, T]
432
  if hps.model.semantic_frame_rate == "25hz":
433
  quantized = F.interpolate(
434
  quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
435
  )
 
 
 
 
436
  _, m_p, logs_p, y_mask = vq_model.enc_p(
437
- quantized, torch.LongTensor([quantized.shape[-1]]),
438
- torch.LongTensor(phones)[None], torch.LongTensor([len(phones)]), ge
439
  )
440
  z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
441
  z = vq_model.flow(z_p, y_mask, g=ge, reverse=True)
 
140
 
141
  def change_sovits_weights(sovits_path):
142
  global vq_model, hps
143
+ dict_s2 = torch.load(sovits_path)
144
  hps = dict_s2["config"]
145
  hps = DictToAttrRecursive(hps)
146
  hps.model.semantic_frame_rate = "25hz"
 
168
  def change_gpt_weights(gpt_path):
169
  global hz, max_sec, t2s_model, config
170
  hz = 50
171
+ dict_s1 = torch.load(gpt_path)
172
  config = dict_s1["config"]
173
  max_sec = config["data"]["max_sec"]
174
  t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
 
426
  phones, word2ph, norm_text = get_cleaned_text_final(text, language)
427
 
428
  spec = get_spepc(hps, prompt_wav)
429
+ spec = spec.to(device)
430
+ codes = get_code_from_wav(wav_path)[None, None].to(device) # 必须是 3D, [n_q, B, T]
431
  ge = vq_model.ref_enc(spec) # [B, D, T/1]
432
  quantized = vq_model.quantizer.decode(codes) # [B, D, T]
433
  if hps.model.semantic_frame_rate == "25hz":
434
  quantized = F.interpolate(
435
  quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
436
  )
437
+ lengths_tensor = torch.LongTensor([quantized.shape[-1]]).to(device)
438
+ phones_tensor = torch.LongTensor(phones)[None].to(device)
439
+ phones_lengths_tensor = torch.LongTensor([len(phones)]).to(device)
440
+
441
  _, m_p, logs_p, y_mask = vq_model.enc_p(
442
+ quantized, lengths_tensor, phones_tensor, phones_lengths_tensor, ge
 
443
  )
444
  z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
445
  z = vq_model.flow(z_p, y_mask, g=ge, reverse=True)