Update api_v2.py
Browse files
api_v2.py
CHANGED
@@ -129,14 +129,23 @@ cut_method_names = get_cut_method_names()
|
|
129 |
import os
|
130 |
import sys
|
131 |
import traceback
|
132 |
-
from typing import Generator
|
133 |
import requests
|
134 |
import tempfile
|
135 |
import urllib.parse
|
136 |
from pathlib import Path
|
137 |
|
138 |
# Function to check if a path is a URL and download it if needed
|
139 |
-
def process_audio_path(audio_path):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
if audio_path and (audio_path.startswith('http://') or audio_path.startswith('https://') or
|
141 |
audio_path.startswith('s3://')):
|
142 |
try:
|
@@ -175,13 +184,13 @@ def process_audio_path(audio_path):
|
|
175 |
f.write(chunk)
|
176 |
|
177 |
print(f"Downloaded to: {local_path}")
|
178 |
-
return local_path
|
179 |
except Exception as e:
|
180 |
print(f"Error downloading audio file: {e}")
|
181 |
raise Exception(f"Failed to download audio from URL: {e}")
|
182 |
|
183 |
# If not a URL or download failed, return the original path
|
184 |
-
return audio_path
|
185 |
|
186 |
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
|
187 |
parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
|
@@ -332,40 +341,14 @@ def check_params(req:dict):
|
|
332 |
async def tts_handle(req:dict):
|
333 |
"""
|
334 |
Text to speech handler.
|
335 |
-
|
336 |
-
Args:
|
337 |
-
req (dict):
|
338 |
-
{
|
339 |
-
"text": "", # str.(required) text to be synthesized
|
340 |
-
"text_lang: "", # str.(required) language of the text to be synthesized
|
341 |
-
"ref_audio_path": "", # str.(required) reference audio path or URL
|
342 |
-
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths or URLs
|
343 |
-
"prompt_text": "", # str.(optional) prompt text for the reference audio
|
344 |
-
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
|
345 |
-
"top_k": 5, # int. top k sampling
|
346 |
-
"top_p": 1, # float. top p sampling
|
347 |
-
"temperature": 1, # float. temperature for sampling
|
348 |
-
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
|
349 |
-
"batch_size": 1, # int. batch size for inference
|
350 |
-
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
351 |
-
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
|
352 |
-
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
353 |
-
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
|
354 |
-
"seed": -1, # int. random seed for reproducibility.
|
355 |
-
"media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
|
356 |
-
"streaming_mode": False, # bool. whether to return a streaming response.
|
357 |
-
"parallel_infer": True, # bool.(optional) whether to use parallel inference.
|
358 |
-
"repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
|
359 |
-
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
360 |
-
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
361 |
-
}
|
362 |
-
returns:
|
363 |
-
StreamingResponse: audio stream response.
|
364 |
"""
|
365 |
|
366 |
streaming_mode = req.get("streaming_mode", False)
|
367 |
return_fragment = req.get("return_fragment", False)
|
368 |
media_type = req.get("media_type", "wav")
|
|
|
|
|
|
|
369 |
|
370 |
check_res = check_params(req)
|
371 |
if check_res is not None:
|
@@ -376,36 +359,69 @@ async def tts_handle(req:dict):
|
|
376 |
|
377 |
try:
|
378 |
# Process ref_audio_path (download if it's a URL)
|
379 |
-
|
|
|
|
|
|
|
380 |
|
381 |
# Process aux_ref_audio_paths (download if they're URLs)
|
382 |
if req.get("aux_ref_audio_paths"):
|
383 |
aux_paths = []
|
384 |
for aux_path in req["aux_ref_audio_paths"]:
|
385 |
-
|
|
|
|
|
|
|
386 |
req["aux_ref_audio_paths"] = aux_paths
|
387 |
|
388 |
-
tts_generator=tts_pipeline.run(req)
|
389 |
|
390 |
if streaming_mode:
|
391 |
-
def streaming_generator(tts_generator:Generator, media_type:str):
|
392 |
if_frist_chunk = True
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
402 |
else:
|
403 |
sr, audio_data = next(tts_generator)
|
404 |
audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
405 |
return Response(audio_data, media_type=f"audio/{media_type}")
|
406 |
except Exception as e:
|
407 |
-
|
408 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
409 |
|
410 |
|
411 |
|
@@ -479,13 +495,32 @@ async def tts_post_endpoint(request: TTS_Request):
|
|
479 |
|
480 |
@APP.get("/set_refer_audio")
|
481 |
async def set_refer_aduio(refer_audio_path: str = None):
|
|
|
482 |
try:
|
483 |
# Process the path (download if it's a URL)
|
484 |
-
local_path = process_audio_path(refer_audio_path)
|
|
|
|
|
|
|
|
|
485 |
tts_pipeline.set_ref_audio(local_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
486 |
except Exception as e:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
487 |
return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)})
|
488 |
-
return JSONResponse(status_code=200, content={"message": "success"})
|
489 |
|
490 |
|
491 |
# @APP.post("/set_refer_audio")
|
|
|
129 |
import os
|
130 |
import sys
|
131 |
import traceback
|
132 |
+
from typing import Generator, Tuple
|
133 |
import requests
|
134 |
import tempfile
|
135 |
import urllib.parse
|
136 |
from pathlib import Path
|
137 |
|
138 |
# Function to check if a path is a URL and download it if needed
|
139 |
+
def process_audio_path(audio_path) -> Tuple[str, bool]:
|
140 |
+
"""
|
141 |
+
Process an audio path, downloading it if it's a URL.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
audio_path (str): Path or URL to audio file
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
Tuple[str, bool]: (local_path, is_temporary)
|
148 |
+
"""
|
149 |
if audio_path and (audio_path.startswith('http://') or audio_path.startswith('https://') or
|
150 |
audio_path.startswith('s3://')):
|
151 |
try:
|
|
|
184 |
f.write(chunk)
|
185 |
|
186 |
print(f"Downloaded to: {local_path}")
|
187 |
+
return local_path, True # Return path and flag indicating it's temporary
|
188 |
except Exception as e:
|
189 |
print(f"Error downloading audio file: {e}")
|
190 |
raise Exception(f"Failed to download audio from URL: {e}")
|
191 |
|
192 |
# If not a URL or download failed, return the original path
|
193 |
+
return audio_path, False # Not a temporary file
|
194 |
|
195 |
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
|
196 |
parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
|
|
|
341 |
async def tts_handle(req:dict):
|
342 |
"""
|
343 |
Text to speech handler.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
344 |
"""
|
345 |
|
346 |
streaming_mode = req.get("streaming_mode", False)
|
347 |
return_fragment = req.get("return_fragment", False)
|
348 |
media_type = req.get("media_type", "wav")
|
349 |
+
temp_files = [] # Track temporary files for cleanup
|
350 |
+
|
351 |
+
print(f"----------现在使用的模型版本是:{tts_config.version}----------")
|
352 |
|
353 |
check_res = check_params(req)
|
354 |
if check_res is not None:
|
|
|
359 |
|
360 |
try:
|
361 |
# Process ref_audio_path (download if it's a URL)
|
362 |
+
ref_path, is_temp = process_audio_path(req["ref_audio_path"])
|
363 |
+
req["ref_audio_path"] = ref_path
|
364 |
+
if is_temp:
|
365 |
+
temp_files.append(ref_path)
|
366 |
|
367 |
# Process aux_ref_audio_paths (download if they're URLs)
|
368 |
if req.get("aux_ref_audio_paths"):
|
369 |
aux_paths = []
|
370 |
for aux_path in req["aux_ref_audio_paths"]:
|
371 |
+
local_path, is_temp = process_audio_path(aux_path)
|
372 |
+
aux_paths.append(local_path)
|
373 |
+
if is_temp:
|
374 |
+
temp_files.append(local_path)
|
375 |
req["aux_ref_audio_paths"] = aux_paths
|
376 |
|
377 |
+
tts_generator = tts_pipeline.run(req)
|
378 |
|
379 |
if streaming_mode:
|
380 |
+
async def streaming_generator(tts_generator:Generator, media_type:str):
|
381 |
if_frist_chunk = True
|
382 |
+
try:
|
383 |
+
for sr, chunk in tts_generator:
|
384 |
+
if if_frist_chunk and media_type == "wav":
|
385 |
+
yield wave_header_chunk(sample_rate=sr)
|
386 |
+
media_type = "raw"
|
387 |
+
if_frist_chunk = False
|
388 |
+
yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
|
389 |
+
finally:
|
390 |
+
# Clean up temporary files after streaming completes
|
391 |
+
for temp_file in temp_files:
|
392 |
+
try:
|
393 |
+
if os.path.exists(temp_file):
|
394 |
+
os.remove(temp_file)
|
395 |
+
print(f"Removed temporary file: {temp_file}")
|
396 |
+
except Exception as e:
|
397 |
+
print(f"Error removing temporary file {temp_file}: {e}")
|
398 |
+
|
399 |
+
return StreamingResponse(streaming_generator(tts_generator, media_type), media_type=f"audio/{media_type}")
|
400 |
else:
|
401 |
sr, audio_data = next(tts_generator)
|
402 |
audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
|
403 |
+
|
404 |
+
# Clean up temporary files after generation completes
|
405 |
+
for temp_file in temp_files:
|
406 |
+
try:
|
407 |
+
if os.path.exists(temp_file):
|
408 |
+
os.remove(temp_file)
|
409 |
+
print(f"Removed temporary file: {temp_file}")
|
410 |
+
except Exception as e:
|
411 |
+
print(f"Error removing temporary file {temp_file}: {e}")
|
412 |
+
|
413 |
return Response(audio_data, media_type=f"audio/{media_type}")
|
414 |
except Exception as e:
|
415 |
+
# Clean up temporary files in case of error
|
416 |
+
for temp_file in temp_files:
|
417 |
+
try:
|
418 |
+
if os.path.exists(temp_file):
|
419 |
+
os.remove(temp_file)
|
420 |
+
print(f"Removed temporary file: {temp_file}")
|
421 |
+
except Exception as cleanup_error:
|
422 |
+
print(f"Error removing temporary file {temp_file}: {cleanup_error}")
|
423 |
+
|
424 |
+
return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)})
|
425 |
|
426 |
|
427 |
|
|
|
495 |
|
496 |
@APP.get("/set_refer_audio")
|
497 |
async def set_refer_aduio(refer_audio_path: str = None):
|
498 |
+
temp_file = None
|
499 |
try:
|
500 |
# Process the path (download if it's a URL)
|
501 |
+
local_path, is_temp = process_audio_path(refer_audio_path)
|
502 |
+
if is_temp:
|
503 |
+
temp_file = local_path
|
504 |
+
|
505 |
+
# Store reference to the audio
|
506 |
tts_pipeline.set_ref_audio(local_path)
|
507 |
+
|
508 |
+
# If temporary, remove after setting (since TTS pipeline should load the audio into memory)
|
509 |
+
if temp_file and os.path.exists(temp_file):
|
510 |
+
os.remove(temp_file)
|
511 |
+
print(f"Removed temporary file: {temp_file}")
|
512 |
+
|
513 |
+
return JSONResponse(status_code=200, content={"message": "success"})
|
514 |
except Exception as e:
|
515 |
+
# Clean up temp file in case of error
|
516 |
+
if temp_file and os.path.exists(temp_file):
|
517 |
+
try:
|
518 |
+
os.remove(temp_file)
|
519 |
+
print(f"Removed temporary file: {temp_file}")
|
520 |
+
except Exception as cleanup_error:
|
521 |
+
print(f"Error removing temporary file {temp_file}: {cleanup_error}")
|
522 |
+
|
523 |
return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)})
|
|
|
524 |
|
525 |
|
526 |
# @APP.post("/set_refer_audio")
|