XXXXRT666 committed
Commit c508945 · Parent: 25a23a4
AR/models/structs.py CHANGED
@@ -17,6 +17,7 @@ Tensor = torch.Tensor
 @dataclass
 class T2SResult:
     result: List[Tensor] | None = None
+    infer_speed: float = 0.0
     status: Literal["Success", "Error"] = "Success"
     exception: Optional[Exception] = None
     traceback: Optional[str] = None
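
For context, a minimal sketch of how the extended result type could be consumed downstream. It assumes only the fields shown in the hunk above; the `report` helper and the tensor value are hypothetical, added for illustration:

from dataclasses import dataclass
from typing import List, Literal, Optional

import torch

Tensor = torch.Tensor


@dataclass
class T2SResult:
    # Fields mirror the hunk above: decoded semantic tokens plus the measured speed.
    result: List[Tensor] | None = None
    infer_speed: float = 0.0
    status: Literal["Success", "Error"] = "Success"
    exception: Optional[Exception] = None
    traceback: Optional[str] = None


def report(res: T2SResult) -> str:
    # Hypothetical helper: summarize a result for logging.
    if res.status != "Success" or res.result is None:
        return f"T2S failed: {res.exception!r}"
    return f"{len(res.result)} sequence(s) decoded at {res.infer_speed:.2f} token/s"


print(report(T2SResult(result=[torch.zeros(8, dtype=torch.long)], infer_speed=75.3)))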
AR/models/t2s_model_flash_attn.py CHANGED
@@ -226,7 +226,7 @@ class CUDAGraphRunner:
         self.kv_cache = decoder_model.init_cache(1)
         self.input_pos = torch.tensor([10]).int().cuda()
 
-    def _handle_request(self, request: T2SRequest) -> List[torch.Tensor]:
+    def _handle_request(self, request: T2SRequest):
         with self.device:
             for i in self.kv_cache:
                 i.empty()
@@ -236,6 +236,7 @@ class CUDAGraphRunner:
             self.input_pos.copy_(session.input_pos)
 
             t1 = 0.0
+            infer_speed = 0.0
             y = session.y
             bsz = y.size(0)
             torch_profiler = TorchProfiler(request.debug)
@@ -314,9 +315,7 @@ class CUDAGraphRunner:
                     f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> \n{[i.size(0) for i in session.y_results].__str__().strip('[]')}"
                 )
                 tqdm.write(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
-                gr.Info(
-                    f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s", duration=0.75
-                )
+                infer_speed = (idx - 1) / (time.perf_counter() - t1)
                 break
 
             if (
@@ -362,12 +361,12 @@ class CUDAGraphRunner:
                 torch.mtia.empty_cache()
 
             torch_profiler.end()
-            return session.y_results[: request.valid_length]
+            return session.y_results[: request.valid_length], infer_speed
 
     def generate(self, request: T2SRequest):
        try:
-            result = self._handle_request(request)
-            t2s_result = T2SResult(result=result, status="Success")
+            result, infer_speed = self._handle_request(request)
+            t2s_result = T2SResult(result=result, infer_speed=infer_speed, status="Success")
         except Exception as e:
             t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
         return t2s_result
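
The speed stored in `infer_speed` is `(idx - 1) / (time.perf_counter() - t1)`: decoded steps over elapsed time, with the first step excluded. A standalone sketch of that pattern, under the assumption that `t1` is captured right after the first step; `decode_step` is a hypothetical stand-in for the model's single-token step, not part of the commit:

import time


def measure_decode_speed(decode_step, max_steps: int) -> float:
    # Start the clock after the first step so that (idx - 1) tokens fall inside
    # the timed window, mirroring the (idx - 1) / (time.perf_counter() - t1)
    # expression used in the decoding loop above.
    t1 = 0.0
    idx = 0
    for idx in range(1, max_steps + 1):
        decode_step()
        if idx == 1:
            t1 = time.perf_counter()
    return (idx - 1) / (time.perf_counter() - t1) if idx > 1 else 0.0


# Hypothetical usage with a step that just sleeps for 1 ms:
print(f"Infer Speed: {measure_decode_speed(lambda: time.sleep(0.001), 50):.2f} token/s")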
inference_webui.py CHANGED
@@ -519,6 +519,8 @@ def get_tts_wav(
     if not ref_free:
         phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
 
+    infer_speed: list[float] = []
+
     for i_text, text in enumerate(texts):
         # 解决输入目标文本的空行导致报错的问题
         if len(text.strip()) == 0:
@@ -559,11 +561,15 @@ def get_tts_wav(
                 # debug=True,
             )
             t2s_result = t2s_model.generate(t2s_request)
-            pred_semantic = t2s_result.result
-            if pred_semantic is None:
+
+            if t2s_result.exception is not None:
                 print(t2s_result.exception)
                 print(t2s_result.traceback)
                 raise RuntimeError("")
+
+            infer_speed.append(t2s_result.infer_speed)
+            pred_semantic = t2s_result.result
+            assert pred_semantic
             cache[i_text] = pred_semantic
             t3 = ttime()
             refers = []
@@ -601,6 +607,7 @@ def get_tts_wav(
         t.extend([t2 - t1, t3 - t2, t4 - t3])
         t1 = ttime()
     print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
+    gr.Info(f"Infer Speed: {sum(infer_speed) / len(infer_speed):.2f} Token/s")
     gr.Info("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])), duration=4)
     yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
 
@@ -735,7 +742,7 @@ def html_left(text, label="p"):
     </div>"""
 
 
-with gr.Blocks(title="GPT-SoVITS WebUI", theme=themes.Monochrome(), analytics_enabled=False) as app:
+with gr.Blocks(title="GPT-SoVITS WebUI", theme=themes.Soft(), analytics_enabled=False) as app:
     gr.Markdown(
         value="""# GPT-SoVITS-ProPlus Zero-shot TTS demo
 ## https://github.com/RVC-Boss/GPT-SoVITS
@@ -837,8 +844,39 @@ This demo is open source under the MIT license. The author does not have any con
     )
 
 if __name__ == "__main__":
+    import tempfile
+    import wave
+
+    with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_file:
+        file_name = temp_file.name
+        with wave.open(temp_file, "w") as wav_file:
+            channels = 1
+            sample_width = 2
+            sample_rate = 44100
+            duration = 5
+            frequency = 440.0
+
+            t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
+            sine_wave = np.sin(2 * np.pi * frequency * t)  # Sine Wave
+            int_wave = (sine_wave * 32767).astype(np.int16)
+
+            wav_file.setnchannels(channels)  # pylint: disable=no-member
+            wav_file.setsampwidth(sample_width)  # pylint: disable=no-member
+            wav_file.setframerate(sample_rate)  # pylint: disable=no-member
+            wav_file.writeframes(int_wave.tobytes())  # pylint: disable=no-member
+
+        gen = get_tts_wav(
+            ref_wav_path=file_name,
+            prompt_text="",
+            prompt_language=i18n("中文"),
+            text="犯大吴疆土者,盛必击而破之,犯大吴疆土者,盛必击而破之,犯大吴疆土者,盛必击而破之,犯大吴疆土者,盛必击而破之",
+            text_language=i18n("中文"),
+        )
+        next(gen)
+
     app.queue().launch(
         server_name="0.0.0.0",
         inbrowser=True,
         show_api=False,
+        allowed_paths=["/"],
     )
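
The per-segment values collected in `infer_speed` are averaged for the toast shown after synthesis. A small sketch of that computation outside the Gradio UI; the function name and the empty-list guard are additions for illustration, not part of the commit:

def average_speed_message(speeds: list[float]) -> str:
    # Same arithmetic as the gr.Info call above: mean of the per-segment speeds.
    if not speeds:  # guard added only for this sketch
        return "Infer Speed: n/a"
    return f"Infer Speed: {sum(speeds) / len(speeds):.2f} Token/s"


print(average_speed_message([78.4, 81.2, 79.9]))  # -> Infer Speed: 79.83 Token/s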