Spaces:
Runtime error
Runtime error
XXXXRT666
committed on
Commit
·
c508945
1
Parent(s):
25a23a4
- AR/models/structs.py +1 -0
- AR/models/t2s_model_flash_attn.py +6 -7
- inference_webui.py +41 -3
AR/models/structs.py
CHANGED
@@ -17,6 +17,7 @@ Tensor = torch.Tensor
|
|
17 |
@dataclass
|
18 |
class T2SResult:
|
19 |
result: List[Tensor] | None = None
|
|
|
20 |
status: Literal["Success", "Error"] = "Success"
|
21 |
exception: Optional[Exception] = None
|
22 |
traceback: Optional[str] = None
|
|
|
17 |
@dataclass
|
18 |
class T2SResult:
|
19 |
result: List[Tensor] | None = None
|
20 |
+
infer_speed: float = 0.0
|
21 |
status: Literal["Success", "Error"] = "Success"
|
22 |
exception: Optional[Exception] = None
|
23 |
traceback: Optional[str] = None
|
AR/models/t2s_model_flash_attn.py
CHANGED
@@ -226,7 +226,7 @@ class CUDAGraphRunner:
|
|
226 |
self.kv_cache = decoder_model.init_cache(1)
|
227 |
self.input_pos = torch.tensor([10]).int().cuda()
|
228 |
|
229 |
-
def _handle_request(self, request: T2SRequest)
|
230 |
with self.device:
|
231 |
for i in self.kv_cache:
|
232 |
i.empty()
|
@@ -236,6 +236,7 @@ class CUDAGraphRunner:
|
|
236 |
self.input_pos.copy_(session.input_pos)
|
237 |
|
238 |
t1 = 0.0
|
|
|
239 |
y = session.y
|
240 |
bsz = y.size(0)
|
241 |
torch_profiler = TorchProfiler(request.debug)
|
@@ -314,9 +315,7 @@ class CUDAGraphRunner:
|
|
314 |
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> \n{[i.size(0) for i in session.y_results].__str__().strip('[]')}"
|
315 |
)
|
316 |
tqdm.write(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
|
317 |
-
|
318 |
-
f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s", duration=0.75
|
319 |
-
)
|
320 |
break
|
321 |
|
322 |
if (
|
@@ -362,12 +361,12 @@ class CUDAGraphRunner:
|
|
362 |
torch.mtia.empty_cache()
|
363 |
|
364 |
torch_profiler.end()
|
365 |
-
return session.y_results[: request.valid_length]
|
366 |
|
367 |
def generate(self, request: T2SRequest):
|
368 |
try:
|
369 |
-
result = self._handle_request(request)
|
370 |
-
t2s_result = T2SResult(result=result, status="Success")
|
371 |
except Exception as e:
|
372 |
t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
|
373 |
return t2s_result
|
|
|
226 |
self.kv_cache = decoder_model.init_cache(1)
|
227 |
self.input_pos = torch.tensor([10]).int().cuda()
|
228 |
|
229 |
+
def _handle_request(self, request: T2SRequest):
|
230 |
with self.device:
|
231 |
for i in self.kv_cache:
|
232 |
i.empty()
|
|
|
236 |
self.input_pos.copy_(session.input_pos)
|
237 |
|
238 |
t1 = 0.0
|
239 |
+
infer_speed = 0.0
|
240 |
y = session.y
|
241 |
bsz = y.size(0)
|
242 |
torch_profiler = TorchProfiler(request.debug)
|
|
|
315 |
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> \n{[i.size(0) for i in session.y_results].__str__().strip('[]')}"
|
316 |
)
|
317 |
tqdm.write(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
|
318 |
+
infer_speed = (idx - 1) / (time.perf_counter() - t1)
|
|
|
|
|
319 |
break
|
320 |
|
321 |
if (
|
|
|
361 |
torch.mtia.empty_cache()
|
362 |
|
363 |
torch_profiler.end()
|
364 |
+
return session.y_results[: request.valid_length], infer_speed
|
365 |
|
366 |
def generate(self, request: T2SRequest):
|
367 |
try:
|
368 |
+
result, infer_speed = self._handle_request(request)
|
369 |
+
t2s_result = T2SResult(result=result, infer_speed=infer_speed, status="Success")
|
370 |
except Exception as e:
|
371 |
t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
|
372 |
return t2s_result
|
inference_webui.py
CHANGED
@@ -519,6 +519,8 @@ def get_tts_wav(
|
|
519 |
if not ref_free:
|
520 |
phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
|
521 |
|
|
|
|
|
522 |
for i_text, text in enumerate(texts):
|
523 |
# 解决输入目标文本的空行导致报错的问题
|
524 |
if len(text.strip()) == 0:
|
@@ -559,11 +561,15 @@ def get_tts_wav(
|
|
559 |
# debug=True,
|
560 |
)
|
561 |
t2s_result = t2s_model.generate(t2s_request)
|
562 |
-
|
563 |
-
if
|
564 |
print(t2s_result.exception)
|
565 |
print(t2s_result.traceback)
|
566 |
raise RuntimeError("")
|
|
|
|
|
|
|
|
|
567 |
cache[i_text] = pred_semantic
|
568 |
t3 = ttime()
|
569 |
refers = []
|
@@ -601,6 +607,7 @@ def get_tts_wav(
|
|
601 |
t.extend([t2 - t1, t3 - t2, t4 - t3])
|
602 |
t1 = ttime()
|
603 |
print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
|
|
|
604 |
gr.Info("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])), duration=4)
|
605 |
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
|
606 |
|
@@ -735,7 +742,7 @@ def html_left(text, label="p"):
|
|
735 |
</div>"""
|
736 |
|
737 |
|
738 |
-
with gr.Blocks(title="GPT-SoVITS WebUI", theme=themes.
|
739 |
gr.Markdown(
|
740 |
value="""# GPT-SoVITS-ProPlus Zero-shot TTS demo
|
741 |
## https://github.com/RVC-Boss/GPT-SoVITS
|
@@ -837,8 +844,39 @@ This demo is open source under the MIT license. The author does not have any con
|
|
837 |
)
|
838 |
|
839 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
840 |
app.queue().launch(
|
841 |
server_name="0.0.0.0",
|
842 |
inbrowser=True,
|
843 |
show_api=False,
|
|
|
844 |
)
|
|
|
519 |
if not ref_free:
|
520 |
phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
|
521 |
|
522 |
+
infer_speed: list[float] = []
|
523 |
+
|
524 |
for i_text, text in enumerate(texts):
|
525 |
# 解决输入目标文本的空行导致报错的问题
|
526 |
if len(text.strip()) == 0:
|
|
|
561 |
# debug=True,
|
562 |
)
|
563 |
t2s_result = t2s_model.generate(t2s_request)
|
564 |
+
|
565 |
+
if t2s_result.exception is not None:
|
566 |
print(t2s_result.exception)
|
567 |
print(t2s_result.traceback)
|
568 |
raise RuntimeError("")
|
569 |
+
|
570 |
+
infer_speed.append(t2s_result.infer_speed)
|
571 |
+
pred_semantic = t2s_result.result
|
572 |
+
assert pred_semantic
|
573 |
cache[i_text] = pred_semantic
|
574 |
t3 = ttime()
|
575 |
refers = []
|
|
|
607 |
t.extend([t2 - t1, t3 - t2, t4 - t3])
|
608 |
t1 = ttime()
|
609 |
print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
|
610 |
+
gr.Info(f"Infer Speed: {sum(infer_speed) / len(infer_speed):.2f} Token/s")
|
611 |
gr.Info("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])), duration=4)
|
612 |
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
|
613 |
|
|
|
742 |
</div>"""
|
743 |
|
744 |
|
745 |
+
with gr.Blocks(title="GPT-SoVITS WebUI", theme=themes.Soft(), analytics_enabled=False) as app:
|
746 |
gr.Markdown(
|
747 |
value="""# GPT-SoVITS-ProPlus Zero-shot TTS demo
|
748 |
## https://github.com/RVC-Boss/GPT-SoVITS
|
|
|
844 |
)
|
845 |
|
846 |
if __name__ == "__main__":
|
847 |
+
import tempfile
|
848 |
+
import wave
|
849 |
+
|
850 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_file:
|
851 |
+
file_name = temp_file.name
|
852 |
+
with wave.open(temp_file, "w") as wav_file:
|
853 |
+
channels = 1
|
854 |
+
sample_width = 2
|
855 |
+
sample_rate = 44100
|
856 |
+
duration = 5
|
857 |
+
frequency = 440.0
|
858 |
+
|
859 |
+
t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
|
860 |
+
sine_wave = np.sin(2 * np.pi * frequency * t) # Sine Wave
|
861 |
+
int_wave = (sine_wave * 32767).astype(np.int16)
|
862 |
+
|
863 |
+
wav_file.setnchannels(channels) # pylint: disable=no-member
|
864 |
+
wav_file.setsampwidth(sample_width) # pylint: disable=no-member
|
865 |
+
wav_file.setframerate(sample_rate) # pylint: disable=no-member
|
866 |
+
wav_file.writeframes(int_wave.tobytes()) # pylint: disable=no-member
|
867 |
+
|
868 |
+
gen = get_tts_wav(
|
869 |
+
ref_wav_path=file_name,
|
870 |
+
prompt_text="",
|
871 |
+
prompt_language=i18n("中文"),
|
872 |
+
text="犯大吴疆土者,盛必击而破之,犯大吴疆土者,盛必击而破之,犯大吴疆土者,盛必击而破之,犯大吴疆土者,盛必击而破之",
|
873 |
+
text_language=i18n("中文"),
|
874 |
+
)
|
875 |
+
next(gen)
|
876 |
+
|
877 |
app.queue().launch(
|
878 |
server_name="0.0.0.0",
|
879 |
inbrowser=True,
|
880 |
show_api=False,
|
881 |
+
allowed_paths=["/"],
|
882 |
)
|