kevinwang676 commited on
Commit
9ea55a5
·
verified ·
1 Parent(s): f1e44ab

Update GPT_SoVITS/app.py

Browse files
Files changed (1) hide show
  1. GPT_SoVITS/app.py +7 -3
GPT_SoVITS/app.py CHANGED
@@ -428,16 +428,20 @@ def vc_main(wav_path, text, language, prompt_wav, noise_scale=0.5):
428
  phones, word2ph, norm_text = get_cleaned_text_final(text, language)
429
 
430
  spec = get_spepc(hps, prompt_wav)
431
- codes = get_code_from_wav(wav_path)[None, None] # 必须是 3D, [n_q, B, T]
 
432
  ge = vq_model.ref_enc(spec) # [B, D, T/1]
433
  quantized = vq_model.quantizer.decode(codes) # [B, D, T]
434
  if hps.model.semantic_frame_rate == "25hz":
435
  quantized = F.interpolate(
436
  quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
437
  )
 
 
 
 
438
  _, m_p, logs_p, y_mask = vq_model.enc_p(
439
- quantized, torch.LongTensor([quantized.shape[-1]]),
440
- torch.LongTensor(phones)[None], torch.LongTensor([len(phones)]), ge
441
  )
442
  z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
443
  z = vq_model.flow(z_p, y_mask, g=ge, reverse=True)
 
428
  phones, word2ph, norm_text = get_cleaned_text_final(text, language)
429
 
430
  spec = get_spepc(hps, prompt_wav)
431
+ spec = spec.to(device)
432
+ codes = get_code_from_wav(wav_path)[None, None].to(device) # 必须是 3D, [n_q, B, T]
433
  ge = vq_model.ref_enc(spec) # [B, D, T/1]
434
  quantized = vq_model.quantizer.decode(codes) # [B, D, T]
435
  if hps.model.semantic_frame_rate == "25hz":
436
  quantized = F.interpolate(
437
  quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
438
  )
439
+ lengths_tensor = torch.LongTensor([quantized.shape[-1]]).to(device)
440
+ phones_tensor = torch.LongTensor(phones)[None].to(device)
441
+ phones_lengths_tensor = torch.LongTensor([len(phones)]).to(device)
442
+
443
  _, m_p, logs_p, y_mask = vq_model.enc_p(
444
+ quantized, lengths_tensor, phones_tensor, phones_lengths_tensor, ge
 
445
  )
446
  z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
447
  z = vq_model.flow(z_p, y_mask, g=ge, reverse=True)